Repository: XunhaoLai/native-sparse-attention-triton
Branch: main
Commit: 9bea856c911e
Files: 44
Total size: 359.0 KB

Directory structure:
gitextract_5vrtmrhk/

├── .gitignore
├── LICENSE
├── README.md
├── install_dependency.sh
├── native_sparse_attention/
│   ├── __init__.py
│   ├── infer/
│   │   ├── __init__.py
│   │   ├── inference_func.py
│   │   └── nsa_inference.py
│   ├── model/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── toy_llama.py
│   │   └── toy_nsa_llama.py
│   ├── module/
│   │   ├── __init__.py
│   │   ├── kv_cache.py
│   │   ├── native_sparse_attention.py
│   │   ├── rope.py
│   │   └── self_attention.py
│   └── ops/
│       ├── README.md
│       ├── __init__.py
│       ├── torch/
│       │   ├── __init__.py
│       │   ├── compress_key_value.py
│       │   ├── compressed_attention.py
│       │   ├── compressed_attention_decode.py
│       │   └── topk_sparse_attention.py
│       └── triton/
│           ├── __init__.py
│           ├── compressed_attention.py
│           ├── flash_attention.py
│           ├── flash_attention_decode.py
│           ├── linear_compress.py
│           ├── topk_sparse_attention.py
│           ├── topk_sparse_attention_decode.py
│           ├── utils.py
│           └── weighted_pool.py
├── setup.py
└── test/
    ├── test_compress_key_value.py
    ├── test_compressed_attention.py
    ├── test_flash_attention.py
    ├── test_kv_cache.py
    ├── test_linear_compress.py
    ├── test_nsa_infer.py
    ├── test_nsa_model.py
    ├── test_nsa_module.py
    ├── test_rope.py
    └── test_topk_sparse_attention.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">

# Native Sparse Attention Triton

</div>

This repository implements the sparse attention mechanism introduced in the paper [Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention](https://arxiv.org/abs/2502.11089) and provides an efficient training implementation based on [Triton](https://github.com/triton-lang/triton).

🎉 We now support both training and inference for Native Sparse Attention (variable-length version, including prefilling, decoding, and KV cache management). A toy model, `model.ToyNSALlama`, provides a `forward` function for training and a `generate` function for inference. Feel free to try it out!

## Requirements
Ensure the following dependencies are installed:
- PyTorch >= 2.1.0
- triton >= 3.0.0
- einops >= 0.7.0
- flash_attn >= 2.6.3

## Usage

### Notes
1. PyTorch implementations (`ops.torch`) are intended for debugging only.
2. For production use, prefer Triton operators (`ops.triton`).
3. All implementations follow the variable-length (varlen) convention of `flash_attn_varlen_func`: concatenate the sequences of a batch into a single tensor and describe their boundaries with `cu_seqlens` before use.
4. Only attention head dimensions of at most 128 are supported for now.
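
The packing that note 3 asks for can be sketched in plain PyTorch. `pack_varlen` is a hypothetical helper, not part of the package: it concatenates per-sequence tensors along the token axis and builds int32 `cu_seqlens` offsets in the same layout as the `ops` example further down. The sketch runs on CPU; move the results to the GPU (and to `bfloat16`) before calling the Triton operators.

```python
import torch


def pack_varlen(seqs):
    """Concatenate per-sequence tensors of shape (L_i, H, D) along the token axis.

    Returns the packed (sum(L_i), H, D) tensor, the int32 cumulative offsets
    `cu_seqlens` (length batch + 1, starting at 0), and the longest length.
    """
    lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)  # cast back to int32 on assignment
    return torch.cat(seqs, dim=0), cu_seqlens, int(lengths.max())


# Three sequences with the same lengths as the README example (1024, 7168, 8192 tokens)
seqs = [torch.randn(n, 4, 8) for n in (1024, 7168, 8192)]
packed, cu_seqlens, max_seqlen = pack_varlen(seqs)

# Sequence i lives in packed[cu_seqlens[i] : cu_seqlens[i + 1]]
second = packed[int(cu_seqlens[1]) : int(cu_seqlens[2])]
```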

### Install

You can install `native_sparse_attention` using pip:

```shell
pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git
```

### Functions

The `ops` module implements the functions required for native sparse attention. For detailed usage instructions, see [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme).

You can import those functions from the `ops` module:

```python
import torch
from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention

# input example
num_q_heads = 64
num_kv_heads = 4
head_dim = 128
kernel_size = 32
kernel_stride = 16
block_size = 64
topk = 16
cu_seqlens = torch.Tensor([0, 1024, 8192, 16384]).to(torch.int32).cuda()
query = torch.randn(16384, num_q_heads, head_dim).to(torch.bfloat16).cuda()
key = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()
value = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()

# weight example
w = (
    torch.randn(num_kv_heads, kernel_size * head_dim, head_dim)
    .to(torch.bfloat16)
    .cuda()
)
pe = torch.randn(num_kv_heads, kernel_size, head_dim).to(torch.bfloat16).cuda()

# 1. key value compression
compressed_key, compressed_cu_seqlens = linear_compress(
    key, w, cu_seqlens, kernel_size, kernel_stride, pe
)
compressed_value, _ = linear_compress(
    value, w, cu_seqlens, kernel_size, kernel_stride, None
)

# 2. attention between query and compressed key value
compressed_attn_output, topk_idx = compressed_attention(
    query,
    compressed_key,
    compressed_value,
    kernel_size,
    kernel_stride,
    block_size,
    topk,
    cu_seqlens,
    compressed_cu_seqlens,
    init_blocks=1,
    local_blocks=2,
)

# 3. topk sparse attention
sparse_attn_output = topk_sparse_attention(
    query,
    key,
    value,
    topk_idx,
    block_size,
    cu_seqlens,
)
```
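
To make the shapes in step 1 concrete, here is a pure-PyTorch sketch of what linear compression may compute, inferred only from the tensor shapes above (`w` is `(num_kv_heads, kernel_size * head_dim, head_dim)` and `pe` is `(num_kv_heads, kernel_size, head_dim)`). `linear_compress_reference` is hypothetical. It assumes windows of `kernel_size` tokens taken every `kernel_stride` tokens within each sequence, `pe` added before a per-head projection, and kernel-position-major flattening of each window. It is not the Triton kernel and may differ from it in these details.

```python
import torch


def linear_compress_reference(key, w, cu_seqlens, kernel_size, kernel_stride, pe=None):
    """Hypothetical pure-PyTorch sketch of linear key/value compression.

    key: (total_len, H, D); w: (H, kernel_size * D, D); pe: (H, kernel_size, D) or None.
    Assumes each window is flattened kernel-position-major before the projection.
    """
    num_heads, head_dim = key.shape[1], key.shape[2]
    outputs, counts = [], []
    for b in range(cu_seqlens.numel() - 1):
        seq = key[int(cu_seqlens[b]) : int(cu_seqlens[b + 1])]  # (L, H, D)
        if seq.shape[0] < kernel_size:
            counts.append(0)  # too short to form a single window
            continue
        # unfold -> (num_windows, H, D, kernel_size); swap the last two axes
        windows = seq.unfold(0, kernel_size, kernel_stride).permute(0, 1, 3, 2)
        if pe is not None:
            windows = windows + pe  # broadcast (H, kernel_size, D) over all windows
        flat = windows.reshape(windows.shape[0], num_heads, kernel_size * head_dim)
        outputs.append(torch.einsum("nhk,hkd->nhd", flat, w))  # per-head projection
        counts.append(windows.shape[0])  # (L - kernel_size) // kernel_stride + 1
    compressed = torch.cat(outputs) if outputs else key.new_zeros(0, num_heads, head_dim)
    compressed_cu_seqlens = torch.zeros(len(counts) + 1, dtype=torch.int32, device=key.device)
    compressed_cu_seqlens[1:] = torch.cumsum(torch.tensor(counts, device=key.device), dim=0)
    return compressed, compressed_cu_seqlens


# Small CPU example: three sequences of 32, 5 and 23 tokens
H, D, ks, stride = 2, 4, 8, 4
cu = torch.tensor([0, 32, 37, 60], dtype=torch.int32)
key = torch.randn(60, H, D)
w = torch.randn(H, ks * D, D)
pe = torch.randn(H, ks, D)
ck, ccu = linear_compress_reference(key, w, cu, ks, stride, pe)
```

With the README's `cu_seqlens` (lengths 1024, 7168, and 8192), `kernel_size=32`, and `kernel_stride=16`, this counting rule gives 63, 447, and 511 compressed tokens per sequence.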

### Module

The `module` directory also provides `torch.nn.Module`-based implementations for easy integration into models.

```python
from native_sparse_attention.module import NativeSparseAttention, RopeConfig

NSA_Layer = NativeSparseAttention(
    compress_type="linear",
    hidden_size=4096,
    num_q_heads=64,
    num_kv_heads=4,
    head_dim=128,
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
    rope_config=RopeConfig(
        max_position_embeddings=32768,
        head_dim=128,
        rope_theta=500000,
        rope_scaling={
            "factor": 4.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
    ),
)
```
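
In the NSA paper, each query's output is a gated sum of three branches (compressed attention, selected top-k attention, and sliding-window attention, the last governed here by `window_size`), with gates in [0, 1] produced from the input by an MLP followed by a sigmoid. The sketch below illustrates that combination with a single linear layer standing in for the MLP. `GatedBranchMix` is hypothetical and does not necessarily match how `NativeSparseAttention` computes its gates.

```python
import torch
import torch.nn as nn


class GatedBranchMix(nn.Module):
    """Hypothetical sketch of NSA-style branch gating (not this repository's module).

    Each query head gets three sigmoid gates, one per branch (compressed,
    selected, sliding window), computed here by a single linear layer.
    """

    def __init__(self, hidden_size: int, num_q_heads: int):
        super().__init__()
        self.num_q_heads = num_q_heads
        self.gate_proj = nn.Linear(hidden_size, 3 * num_q_heads)

    def forward(self, hidden, o_cmp, o_slc, o_win):
        # hidden: (T, hidden_size); each o_*: (T, num_q_heads, head_dim)
        gates = torch.sigmoid(self.gate_proj(hidden)).view(-1, self.num_q_heads, 3)
        # gates[..., c : c + 1] keeps a trailing axis so it broadcasts over head_dim
        return gates[..., 0:1] * o_cmp + gates[..., 1:2] * o_slc + gates[..., 2:3] * o_win


# Usage with zeroed gate weights: every gate is sigmoid(0) = 0.5
mix = GatedBranchMix(hidden_size=16, num_q_heads=4)
nn.init.zeros_(mix.gate_proj.weight)
nn.init.zeros_(mix.gate_proj.bias)
hidden = torch.randn(5, 16)
o_cmp, o_slc, o_win = (torch.randn(5, 4, 8) for _ in range(3))
out = mix(hidden, o_cmp, o_slc, o_win)
```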

### Model

We offer two simplified LLaMA models in the `model` directory: one with standard self-attention and one with native sparse attention. For details on how to use them, see [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/model#readme).


```python
from native_sparse_attention.model import ToyNSALlamaConfig, InferenceConfig, ToyNSALlama

config = ToyNSALlamaConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=8,
    num_attention_heads=32,
    num_key_value_heads=2,
    head_dim=128,
    rope_theta=500000.0,
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    compress_type="weightedpool",
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
)
inference_config = InferenceConfig(
    max_batch_size=4,
    max_length=8192,
    max_new_tokens=128,
)
model = ToyNSALlama(config, inference_config).cuda().bfloat16()
```
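
Both configurations above set `"rope_type": "llama3"`. Assuming the RoPE module follows the Llama 3 frequency-scaling rule as implemented in Hugging Face `transformers` (not verified against this repository's `rope.py`), the `rope_scaling` keys act as sketched below: high frequencies are kept, low frequencies are divided by `factor`, and a band in between is interpolated. `llama3_inv_freq` is a hypothetical helper name.

```python
import math

import torch


def llama3_inv_freq(head_dim, rope_theta, factor, low_freq_factor,
                    high_freq_factor, original_max_position_embeddings):
    """Sketch of Llama 3 RoPE frequency scaling (mirrors the transformers rule)."""
    inv_freq = 1.0 / rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim)
    wavelen = 2 * math.pi / inv_freq
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    # Long wavelengths (low frequencies) are divided by `factor`; short ones are kept
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # Medium band: interpolate between the scaled and unscaled frequencies
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)


# Values from the ToyNSALlamaConfig example above
inv_freq = llama3_inv_freq(128, 500000.0, 8.0, 1.0, 4.0, 8192)
```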

## Testing

Some test scripts are available in the `test` folder and can be run directly for unit testing. For example:

```bash
python test/test_topk_sparse_attention.py
python test/test_nsa_module.py
python test/test_nsa_model.py
```

### Benchmarks

Below are speed benchmarks for the `topk_sparse_attention` function on a single NVIDIA A100 or H100 GPU:

A100 GPU speed benchmarks:
```sh
** forward with block size 64 **:
          N       Flash  Triton-Flash  Triton-Top8  Triton-Top16
0    2048.0     0.414144      0.635648     0.633440      1.009184
1    4096.0     1.400304      2.267552     1.179808      1.916736
2    8192.0     5.223776      8.528160     2.266816      3.723168
3   16384.0    20.225697     32.745537     4.468128      7.359168
4   32768.0    79.587715    128.951065     8.517440     14.142848
5   65536.0   321.240479    511.652100    17.249599     30.991360
6  131072.0  1349.810425   2063.245605    36.400482     67.884544

** backward with block size 64 **:
          N        Flash  Triton-Flash  Triton-Top8  Triton-Top16
0    2048.0     1.315440      2.348560     1.941568      2.691040
1    4096.0     4.271584      8.553184     3.647744      5.032160
2    8192.0    15.323984     32.665440     5.650144      9.066112
3   16384.0    58.753281    127.675964    11.160832     17.113279
4   32768.0   227.770462    504.572693    21.723392     34.715614
5   65536.0   899.181274   2059.718506    44.517181     76.309441
6  131072.0  3587.918701   8530.726562   105.344734    182.970169
```

H100 GPU speed benchmarks:
```sh
** forward with block size 64 **:
          N       Flash  Triton-Flash  Triton-Top8  Triton-Top16
0    2048.0    0.259552      0.293888     0.584544      0.917664
1    4096.0    0.846848      1.029904     1.094976      1.745136
2    8192.0    3.043744      3.843392     2.128256      3.396880
3   16384.0   11.743568     14.791360     4.190528      6.704192
4   32768.0   45.968513     57.532478     7.614496     12.417440
5   65536.0  187.234375    228.093948    14.840048     24.511856
6  131072.0  810.890381    914.693970    29.470400     48.990192

** backward with block size 64 **:
          N        Flash  Triton-Flash  Triton-Top8  Triton-Top16
0    2048.0     0.798976      1.096096     1.117312      1.380016
1    4096.0     2.545680      3.826336     1.669760      2.214880
2    8192.0     9.029760     14.411633     2.772096      3.947456
3   16384.0    34.144016     58.945698     5.201344      7.538912
4   32768.0   135.718369    233.369247     9.968864     15.154192
5   65536.0   541.053894    929.337646    21.089870     33.818878
6  131072.0  2139.974854   3785.540527    54.918144     93.750717
```
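
A rough reading of these numbers: with top-k block selection, each query attends to at most `topk * block_size` keys regardless of N, while full causal attention visits about N/2 keys per query on average. That is consistent with the Top8/Top16 columns roughly doubling when N doubles while the Flash columns roughly quadruple. The arithmetic below uses values copied from the A100 forward table; the helper names are illustrative only.

```python
# Forward times copied from the A100 top-k table above (block size 64, N = 131072).
A100_FWD_131072 = {"Flash": 1349.810425, "Triton-Top8": 36.400482, "Triton-Top16": 67.884544}


def max_keys_per_query(topk: int, block_size: int) -> int:
    """Upper bound on the keys one query attends to under top-k block selection."""
    return topk * block_size


def speedup(baseline: float, candidate: float) -> float:
    """Ratio of two measured times (unitless)."""
    return baseline / candidate


N = 131072
for name, topk in (("Triton-Top8", 8), ("Triton-Top16", 16)):
    keys = max_keys_per_query(topk, block_size=64)
    ratio = speedup(A100_FWD_131072["Flash"], A100_FWD_131072[name])
    # A causal query at position t sees t + 1 keys, i.e. about N / 2 on average
    print(f"{name}: <= {keys} keys/query vs ~{N // 2} (causal average); "
          f"{ratio:.1f}x faster than Flash")
```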

The following benchmarks measure the `compressed_attention` function on a single NVIDIA A100 or H100 GPU:

A100 GPU speed benchmarks:
```sh
** forward with kernel 32 and stride 16 **:
          N       Flash  Triton-Flash  Compressed  Compressed-wo-Score
0    2048.0     0.413664      0.635488    0.655024             0.170816
1    4096.0     1.396416      2.247648    1.132304             0.377152
2    8192.0     5.234656      8.526400    2.879200             0.977952
3   16384.0    19.988865     32.755199    9.426448             2.943024
4   32768.0    79.419907    128.955170   30.284096             9.901120
5   65536.0   321.590210    511.615509  112.260544            36.001602
6  131072.0  1346.996338   2069.837891  423.099518           136.820038

** backward with kernel 32 and stride 16 **:
          N        Flash  Triton-Flash  Compressed
0    2048.0     1.322560      2.352000    0.486784
1    4096.0     4.270832      8.552608    0.971392
2    8192.0    15.515680     32.671329    2.603744
3   16384.0    59.345055    128.377472    8.499456
4   32768.0   230.626144    506.581238   30.064833
5   65536.0   919.260498   2068.642578  113.466560
6  131072.0  3646.603760   8498.374023  439.623444
```

H100 GPU speed benchmarks:
```sh
** forward with kernel 32 and stride 16 **:
          N       Flash  Triton-Flash  Compressed  Compressed-wo-Score
0    2048.0    0.259488      0.297152    0.485920             0.103232
1    4096.0    0.847376      1.030400    0.710208             0.217760
2    8192.0    3.044016      3.875840    1.607360             0.516016
3   16384.0   11.823104     14.829360    4.970272             1.440288
4   32768.0   46.204750     57.527809   15.004992             4.584736
5   65536.0  187.324249    227.909958   53.009087            16.134224
6  131072.0  810.707214    910.106873  191.245728            60.154270

** backward with kernel 32 and stride 16 **:
          N        Flash  Triton-Flash  Compressed
0    2048.0     0.797728      1.090640    0.283104
1    4096.0     2.547088      3.834592    0.550464
2    8192.0     9.021520     14.421088    1.249184
3   16384.0    34.159508     58.793377    3.743440
4   32768.0   136.844070    233.447708   12.640032
5   65536.0   537.559814    929.360229   46.054817
6  131072.0  2135.629883   3782.351562  175.587296
```

All the speed benchmarks above were tested with 64 query heads, 4 key/value heads, and a head dimension of 128.
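
A similar back-of-the-envelope check applies to `compressed_attention`. Assuming compressed keys come from windows of `kernel_size=32` tokens taken every `kernel_stride=16` tokens, and treating N as the length of a single sequence (an assumption, since the benchmark setup is not described further), the number of compressed keys is about N/16. The sketch compares that reduction with the measured A100 forward times at N = 131072; `num_compressed_tokens` is a hypothetical helper.

```python
def num_compressed_tokens(seqlen: int, kernel_size: int = 32, kernel_stride: int = 16) -> int:
    """Sliding-window count: windows of `kernel_size` tokens every `kernel_stride` tokens."""
    if seqlen < kernel_size:
        return 0
    return (seqlen - kernel_size) // kernel_stride + 1


# A100 forward times at N = 131072 (kernel 32, stride 16), copied from the table above.
flash, compressed, compressed_wo_score = 1346.996338, 423.099518, 136.820038

n = 131072
m = num_compressed_tokens(n)
print(f"{n} tokens -> {m} compressed keys ({n / m:.1f}x fewer)")
print(f"Flash / Compressed: {flash / compressed:.2f}x, "
      f"Flash / Compressed-wo-Score: {flash / compressed_wo_score:.2f}x")
```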

## Contributing
Contributions are welcome! Please open an issue to discuss major changes.

## Contact

For any questions or feedback, please feel free to contact laixunhao@pku.edu.cn.

## Citations

```bibtex
@inproceedings{Yuan2025NativeSA,
    title   = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention},
    author  = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng},
    year    = {2025},
    url     = {https://api.semanticscholar.org/CorpusID:276408911}
}
```


================================================
FILE: install_dependency.sh
================================================
pip3 install packaging -i https://pypi.org/simple
pip3 install numpy==1.26.4 -i https://pypi.org/simple
pip3 install torch==2.4.0 -i https://pypi.org/simple
pip3 install triton==3.0.0 -i https://pypi.org/simple
pip3 install transformers==4.44.0 -i https://pypi.org/simple
pip3 install flash_attn==2.6.3 -i https://pypi.org/simple
pip3 install matplotlib==3.9.4 -i https://pypi.org/simple
pip3 install pandas==2.2.3 -i https://pypi.org/simple

================================================
FILE: native_sparse_attention/__init__.py
================================================


================================================
FILE: native_sparse_attention/infer/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from native_sparse_attention.infer.nsa_inference import nsa_infer

__all__ = [
    "nsa_infer",
]


================================================
FILE: native_sparse_attention/infer/inference_func.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Tuple, Callable, Optional
from flash_attn import flash_attn_varlen_func
from native_sparse_attention.ops import (
    flash_attention_decode,
    compressed_attention,
    compressed_attention_decode,
    topk_sparse_attention,
    topk_sparse_attention_decode,
)
from native_sparse_attention.ops.triton.utils import get_compressed_seqlens


def compress_infer(
    cu_seqlens: torch.Tensor,
    step: int,
    key: torch.Tensor,
    value: torch.Tensor,
    cache,
    weight: Tuple[torch.Tensor, torch.Tensor],
    compress_func: Tuple[Callable, Callable],
    intra_block_pe: Optional[torch.Tensor],
    kernel_size: int,
    kernel_stride: int,
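):
    """Compress keys and values for NSA inference.

    At prefill (``step == 0``), the full variable-length ``key`` and ``value``
    tensors are compressed with ``compress_func``. At a decode step, the last
    ``kernel_size`` raw tokens of each sequence, kept in
    ``cache.before_compress_kv_cache``, are compressed again; with exactly
    ``kernel_size`` tokens per sequence this yields one compressed token each.

    Returns the compressed key, the compressed value, and the cumulative
    compressed sequence lengths (at decode steps, derived from
    ``cache.compress_kv_len``).
    """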
    if step == 0:
        key, compress_cu_seqlens = compress_func[0](
            key,
            weight[0],
            cu_seqlens,
            kernel_size,
            kernel_stride,
            intra_block_pe,
        )
        value, _ = compress_func[1](
            value,
            weight[1],
            cu_seqlens,
            kernel_size,
            kernel_stride,
        )
    else:
        batch_size = cu_seqlens.shape[0] - 1
        aux_cu_seqlens = (
            torch.arange(batch_size + 1, dtype=torch.int32).to(cu_seqlens.device)
            * kernel_size
        )
        key, _ = compress_func[0](
            cache.before_compress_kv_cache[0, :batch_size].view(
                batch_size * kernel_size, cache.num_kv_heads, cache.head_dim
            ),
            weight[0],
            aux_cu_seqlens,
            kernel_size,
            kernel_stride,
            intra_block_pe,
        )
        value, _ = compress_func[1](
            cache.before_compress_kv_cache[1, :batch_size].view(
                batch_size * kernel_size, cache.num_kv_heads, cache.head_dim
            ),
            weight[1],
            aux_cu_seqlens,
            kernel_size,
            kernel_stride,
        )
        # return actual compress_cu_seqlens before this token
        compress_cu_seqlens = torch.zeros(
            batch_size + 1, dtype=torch.int32, device=key.device
        )
        compress_cu_seqlens[1:] = torch.cumsum(
            cache.compress_kv_len[:batch_size], dim=0
        )
    return key, value, compress_cu_seqlens


def compressed_attention_infer(
    cu_seqlens,
    step,
    query,
    key,
    value,
    cache,
    kernel_size,
    kernel_stride,
    topk,
    block_size,
    init_blocks,
    local_blocks,
):
    if step == 0:
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        compress_seqlens, compress_cu_seqlens = get_compressed_seqlens(
            cu_seqlens, kernel_size, kernel_stride
        )
        attn_output, topk_idx = compressed_attention(
            query,
            key,
            value,
            kernel_size,
            kernel_stride,
            block_size,
            topk,
            cu_seqlens,
            compress_cu_seqlens,
            seqlens.max().item(),
            compress_seqlens.max().item(),
            None,
            init_blocks,
            local_blocks,
        )
    else:
        batch_size = cu_seqlens.shape[0] - 1
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + step
        attn_output, topk_idx = compressed_attention_decode(
            query,
            cache.compress_kv_cache[
                0, :batch_size, : cache.compress_kv_len[:batch_size].max()
            ],
            cache.compress_kv_cache[
                1, :batch_size, : cache.compress_kv_len[:batch_size].max()
            ],
            seqlens,
            cache.compress_kv_len[:batch_size],
            kernel_size,
            kernel_stride,
            block_size,
            topk,
            init_blocks,
            local_blocks,
        )
    return attn_output, topk_idx


def topk_sparse_attention_infer(
    cu_seqlens,
    step,
    query,
    key,
    value,
    cache,
    topk_idx,
    block_size,
):
    if step == 0:
        attn_output = topk_sparse_attention(
            query, key, value, topk_idx, block_size, cu_seqlens
        )
    else:
        batch_size = cu_seqlens.shape[0] - 1
        attn_output = topk_sparse_attention_decode(
            query,
            cache.sparse_kv_cache[0, :batch_size],
            cache.sparse_kv_cache[1, :batch_size],
            topk_idx,
            block_size,
            cache.sparse_kv_len[:batch_size],
        )
    return attn_output


def sliding_window_attention_infer(
    cu_seqlens, step, query, key, value, cache, window_size
):
    if step == 0:
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        attn_output = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens,
            cu_seqlens,
            seqlens.max().item(),
            seqlens.max().item(),
            causal=True,
            window_size=(window_size, -1),
        )
    else:
        batch_size = cu_seqlens.shape[0] - 1
        attn_output = flash_attention_decode(
            query,
            cache.sliding_kv_cache[0, :batch_size],
            cache.sliding_kv_cache[1, :batch_size],
            torch.minimum(
                cache.sliding_kv_len,
                torch.zeros_like(cache.sliding_kv_len) + window_size,
            )[:batch_size],
        )
    return attn_output


================================================
FILE: native_sparse_attention/infer/nsa_inference.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Tuple, Callable, Optional
from native_sparse_attention.infer.inference_func import (
    compress_infer,
    compressed_attention_infer,
    topk_sparse_attention_infer,
    sliding_window_attention_infer,
)


def nsa_infer(
    cu_seqlens: torch.Tensor,
    step: int,
    # qkv for three parts
    query: torch.Tensor,
    key: torch.Tensor,  # prefill: [total_len, num_kv_heads, head_dim], decode: [batch_size, num_kv_heads, head_dim]
    value: torch.Tensor,
    gate_value: torch.Tensor,  # prefill: [total_len, num_q_heads, 3], decode: [batch_size, num_q_heads, 3]
    # rope and kv cache
    rope,
    cache,
    # weight for nsa compress
    compress_weight: Tuple[
        torch.Tensor, torch.Tensor
    ],  # compress weight for key and value
    compress_func: Tuple[Callable, Callable],  # compress function for key and value
    intra_block_pe: Optional[torch.Tensor],
    # nsa parameters
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    init_blocks: int,
    local_blocks: int,
    window_size: int,
) -> torch.Tensor:
    """Inference function for native sparse attention. Supports prefill and decoding with a KV cache.

    Args:
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
        step (int): current inference step, step == 0 means prefill, step > 0 means decode step.
        query (torch.Tensor): for prefill, shape [total_len, num_q_heads, head_dim]; for decode, shape [batch_size, num_q_heads, head_dim]
        key (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim]
        value (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim]
        gate_value (torch.Tensor): for prefill, shape [total_len, num_q_heads, 3]; for decode, shape [batch_size, num_q_heads, 3]
        rope (RotaryEmbedding): rope module, see native_sparse_attention.module.rope.RotaryEmbedding for details
        cache (NSACache): kv cache, see native_sparse_attention.module.kv_cache.NSACache for details
        compress_weight (Tuple[torch.Tensor, torch.Tensor]): compress weight of key and value respectively
        compress_func (Tuple[Callable, Callable]): compress functions for key and value respectively
        intra_block_pe (Optional[torch.Tensor]): intra-block positional embedding for compression, or None if not used
        kernel_size (int): kernel size of compression
        kernel_stride (int): kernel stride for compression
        block_size (int): block size of sparse attention
        topk (int): topk of sparse attention
        init_blocks (int): number of blocks at the beginning of the sequence that are always selected in sparse attention
        local_blocks (int): number of blocks in the local window of each query that are always selected in sparse attention
        window_size (int): window size for sliding window attention

    Returns:
        torch.Tensor: native sparse attention output, same shape as input query
    """
    # reset kv cache at the beginning of prefilling
    if step == 0:
        cache.reset()
    # prepare for compress
    cache.prepare_compress(cu_seqlens, step, key, value)
    # compressed key and value before rope
    compress_key, compress_value, compress_cu_seqlens = compress_infer(
        cu_seqlens,
        step,
        key,
        value,
        cache,
        compress_weight,
        compress_func,
        intra_block_pe,
        kernel_size,
        kernel_stride,
    )
    # do rope
    query = rope(query, cu_seqlens, step)
    if step == 0:
        compress_key = rope(
            compress_key, compress_cu_seqlens, step, stride=cache.kernel_stride
        )
    else:
        compress_key = rope(
            compress_key, compress_cu_seqlens, 1, stride=cache.kernel_stride
        )
    key = rope(key, cu_seqlens, step)
    # update kv cache
    cache.update_kv(
        cu_seqlens,
        step,
        compress_key,
        compress_value,
        key,
        value,
        key,
        value,
    )
    # compressed attention
    compress_attn_output, topk_idx = compressed_attention_infer(
        cu_seqlens,
        step,
        query,
        compress_key,
        compress_value,
        cache,
        kernel_size,
        kernel_stride,
        topk,
        block_size,
        init_blocks,
        local_blocks,
    )
    # topk sparse attention
    sparse_attn_output = topk_sparse_attention_infer(
        cu_seqlens,
        step,
        query,
        key,
        value,
        cache,
        topk_idx,
        block_size,
    )
    # sliding window attention
    sliding_attn_output = sliding_window_attention_infer(
        cu_seqlens, step, query, key, value, cache, window_size
    )
    # combine 3 attn output
    attn_output = (
        gate_value[..., 0, None] * compress_attn_output
        + gate_value[..., 1, None] * sparse_attn_output
        + gate_value[..., 2, None] * sliding_attn_output
    )
    return attn_output


================================================
FILE: native_sparse_attention/model/README.md
================================================
# Guide for the ToyNSALlama Model

The `ToyNSALlama` model is a custom implementation of a Llama-like transformer architecture featuring a Native Sparse Attention (NSA) module. This guide explains how to integrate the NSA module into your own model.

## Overview

The `ToyNSALlama` model consists of:
- **Configuration**: Defined by `ToyNSALlamaConfig` (model structure parameters) and `InferenceConfig` (inference-specific parameters).
- **Components**: An embedding layer, multiple NativeSparseAttention modules, Feed-Forward Network (FFN) modules, normalization layers, and a language model head.
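
Each `NativeSparseAttention` module runs three attention branches (compressed, top-k sparse, and sliding window) and mixes them with per-head gates: `nsa_infer` computes `gate_value[..., 0, None] * compressed + gate_value[..., 1, None] * sparse + gate_value[..., 2, None] * sliding`. The minimal sketch below reproduces that weighted sum on plain Python lists for a single token and head. `combine_branches` and the numbers are hypothetical and only illustrate the arithmetic; the real module works on batched tensors.

```python
from typing import List, Sequence


def combine_branches(
    gates: Sequence[float],
    compressed: List[float],
    sparse: List[float],
    sliding: List[float],
) -> List[float]:
    """Hypothetical helper: mix the three NSA branch outputs for one token/head.

    gates[0], gates[1], gates[2] weight the compressed, top-k sparse and
    sliding-window outputs, mirroring gate_value[..., i, None] * branch_i.
    """
    if len(gates) != 3:
        raise ValueError("expected exactly three gate values")
    g_cmp, g_slc, g_win = gates
    # element-wise weighted sum over head_dim
    return [
        g_cmp * c + g_slc * s + g_win * w
        for c, s, w in zip(compressed, sparse, sliding)
    ]


# One head_dim=4 vector per branch (toy numbers).
out = combine_branches(
    gates=[0.5, 0.25, 0.25],
    compressed=[1.0, 0.0, 2.0, 4.0],
    sparse=[0.0, 4.0, 0.0, 4.0],
    sliding=[4.0, 0.0, 2.0, 0.0],
)
print(out)  # [1.5, 1.0, 1.5, 3.0]
```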

## Step-by-Step Instructions

### 1. Import Necessary Modules
```python
import torch
import torch.nn as nn
from native_sparse_attention.model import ToyNSALlama, ToyNSALlamaConfig, InferenceConfig
```

### 2. Define Configurations
Create instances of `ToyNSALlamaConfig` and `InferenceConfig` to set model and inference parameters.

#### Model Configuration
The model configuration aligns with the Transformers Llama model configuration. Adjust the following parameters to control the NSA module’s sparsity:
- `compress_type`: Compression method for keys/values. Supported options: `avgpool`, `weightedpool`, `linear`.
- `kernel_size` & `kernel_stride`: `kernel_size` sets how many tokens are compressed into one; `kernel_stride` sets the stride of the sliding compression window. `kernel_size` must be divisible by `kernel_stride` (e.g., 32 and 16).
- `block_size`: Block size for sparse attention (recommended: 64 or 128).
- `topk`, `init_blocks`, `local_blocks`: `topk` specifies the number of blocks selected in sparse attention; `init_blocks` and `local_blocks` define the number of initial and local blocks that must be selected.
- `window_size`: Size of the sliding window for attention.

Example:
```python
config = ToyNSALlamaConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=8,
    num_attention_heads=32,
    num_key_value_heads=2,
    head_dim=128,
    vocab_size=128288,
    max_position_embeddings=131072,
    rope_theta=500000.0,
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    compress_type="weightedpool",
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
)
```
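
To get a feel for how these parameters trade off, the sketch below estimates how many keys one query at the end of a sequence can see in each branch with the example values above. It assumes the compression branch slides a `kernel_size` window by `kernel_stride` and emits no token for sequences shorter than one window, that `kernel_size` must be a multiple of `kernel_stride`, and it treats `topk * block_size` as an upper bound on the selected tokens. `compressed_len` and `nsa_budget` are hypothetical helpers, not library functions.

```python
def compressed_len(seq_len: int, kernel_size: int, kernel_stride: int) -> int:
    """Assumed count of compressed tokens for one sequence of length seq_len."""
    if seq_len < kernel_size:
        # too short for a single compression window
        return 0
    return (seq_len - kernel_size) // kernel_stride + 1


def nsa_budget(
    seq_len: int,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    window_size: int,
) -> dict:
    """Hypothetical helper: rough per-query key counts for each NSA branch."""
    # Assumed constraint, consistent with the example (32 is a multiple of 16).
    if kernel_size % kernel_stride != 0:
        raise ValueError("kernel_size must be divisible by kernel_stride")
    return {
        "compressed_tokens": compressed_len(seq_len, kernel_size, kernel_stride),
        # upper bound: topk blocks of block_size tokens, capped by the sequence
        "selected_tokens": min(topk * block_size, seq_len),
        "window_tokens": min(window_size, seq_len),
    }


budget = nsa_budget(
    8192, kernel_size=32, kernel_stride=16, block_size=64, topk=8, window_size=512
)
print(budget)
```

For an 8192-token sequence this gives 511 compressed tokens, at most 512 selected tokens, and a 512-token window: roughly 1.5K keys per query instead of 8192.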

#### Inference Configuration
This configuration applies during inference: the Key-Value (KV) cache is initialized from these settings. In `bfloat16`, the full KV cache occupies `max_batch_size × max_length × num_kv_heads × head_dim × num_layers × 2 × 2` bytes (one factor of 2 for keys and values, the other for the 2 bytes per element). The example `generate` method currently supports greedy decoding only.

Example:
```python
inference_config = InferenceConfig(
    max_batch_size=4,
    max_length=8192,
    max_new_tokens=128,
)
```
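
Plugging the example values into the size formula is a quick sanity check. This sketch assumes `bfloat16` (2 bytes per element), counts keys and values for every layer including `head_dim`, and takes `num_layers=8` and `head_dim=128` from the model example above. The extra compressed, sliding-window, and pre-compression buffers that `NSACache` also keeps are not included. `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(
    max_batch_size: int,
    max_length: int,
    num_kv_heads: int,
    head_dim: int,
    num_layers: int,
    bytes_per_elem: int = 2,  # bfloat16
) -> int:
    """Hypothetical helper: bytes for full-length K and V caches across all layers."""
    per_tensor = max_batch_size * max_length * num_kv_heads * head_dim
    # x2 for keys and values, x num_layers, x bytes per element
    return per_tensor * 2 * num_layers * bytes_per_elem


size = kv_cache_bytes(
    max_batch_size=4, max_length=8192, num_kv_heads=2, head_dim=128, num_layers=8
)
print(size, size / 2**20)  # 268435456 256.0
```

The example settings therefore need about 256 MiB for the full-length K/V caches; doubling `max_length` or `max_batch_size` doubles it.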

### 3. Initialize the Model
Instantiate the model and move it to the GPU with the appropriate data type (currently, only `bfloat16` is supported).

```python
model = ToyNSALlama(config, inference_config).cuda().to(torch.bfloat16)
```

### 4. Forward & Generate
The model supports two methods:
- **`forward`**: Accepts `input_ids` and `cu_seqlens`, returning final logits after the language model head. Use this for training or evaluation.
- **`generate`**: Accepts `input_ids` and `cu_seqlens`, generating output tokens via greedy sampling. This demonstrates KV cache usage for token generation (pre-filling and decoding).

Example:
```python
# Example input
batch_size = 4
seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
input_ids = torch.randint(0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda")
print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")

# Example forward
logits = model(input_ids, cu_seqlens)
print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")

# Example generate
output_tokens = model.generate(input_ids, cu_seqlens)
print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
```
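
Both methods take a packed (varlen) batch: all sequences are concatenated into one `input_ids` vector, and `cu_seqlens` stores the offset where each sequence starts, with a final entry equal to the total length. The pure-Python sketch below mirrors three index computations from the inference code: building `cu_seqlens` with a cumulative sum, picking each sequence's last token after prefill (`x[cu_seqlens[1:] - 1]`), and the per-sequence length used at decode step `step` (`cu_seqlens[1:] - cu_seqlens[:-1] + step`). The helper names are hypothetical.

```python
from itertools import accumulate
from typing import List


def build_cu_seqlens(seqlens: List[int]) -> List[int]:
    # cu_seqlens[i] is the offset of sequence i in the packed token stream
    return [0] + list(accumulate(seqlens))


def last_token_indices(cu_seqlens: List[int]) -> List[int]:
    # mirrors x[cu_seqlens[1:] - 1]: index of each sequence's final token
    return [end - 1 for end in cu_seqlens[1:]]


def decode_seqlens(cu_seqlens: List[int], step: int) -> List[int]:
    # mirrors cu_seqlens[1:] - cu_seqlens[:-1] + step in the decode path
    return [end - start + step for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]


cu = build_cu_seqlens([5, 3, 7])
print(cu)                      # [0, 5, 8, 15]
print(last_token_indices(cu))  # [4, 7, 14]
print(decode_seqlens(cu, 2))   # [7, 5, 9]
```

An empty sequence would make its last-token index point at a token from another sequence, so each sequence in the batch should contain at least one token.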

## Toy Llama Model with Self-Attention
A simpler toy model with the Llama structure is available in `native_sparse_attention/model/toy_llama.py`. Compare `ToyLlama` and `ToyNSALlama` to see how to adapt a self-attention model into an NSA-based model.

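The main difference is that `SelfAttention` is replaced by `NativeSparseAttention`, and the per-layer `KVCache` by `NSACache`, which additionally takes `kernel_size`, `kernel_stride`, and `window_size`. Each layer's `inference` method then calls `self.nsa.inference` instead of `self.self_attn.inference`; the embedding, FFN, normalization, and language model head are the same.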


================================================
FILE: native_sparse_attention/model/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from native_sparse_attention.model.toy_llama import (
    ToyLlamaConfig,
    InferenceConfig,
    ToyLlama,
)
from native_sparse_attention.model.toy_nsa_llama import (
    ToyNSALlamaConfig,
    InferenceConfig,
    ToyNSALlama,
)

__all__ = [
    "ToyLlamaConfig",
    "ToyNSALlamaConfig",
    "InferenceConfig",
    "ToyLlama",
    "ToyNSALlama",
]


================================================
FILE: native_sparse_attention/model/toy_llama.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from native_sparse_attention.module import SelfAttention, RopeConfig, KVCache


@dataclass
class ToyLlamaConfig:
    # embedding config
    vocab_size: int = 128288
    max_position_embeddings: int = 131072
    # model config
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 2
    head_dim: int = 128
    # rope config
    rope_theta: float = 500000.0
    rope_scaling: dict = field(
        default_factory=lambda: {
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        }
    )


@dataclass
class InferenceConfig:
    max_batch_size: int = 32
    max_length: int = 8192
    max_new_tokens: int = 128


class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class FFN(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class ToyLlamaLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rope_config: RopeConfig,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.rope_config = rope_config
        self.attn_norm = RMSNorm(self.hidden_size)
        self.self_attn = SelfAttention(
            hidden_size=self.hidden_size,
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            rope_config=rope_config,
        )
        self.ffn_norm = RMSNorm(self.hidden_size)
        self.ffn = FFN(
            hidden_size=self.hidden_size, intermediate_size=self.intermediate_size
        )

    def forward(self, x, cu_seqlens):
        x = x + self.self_attn(self.attn_norm(x), cu_seqlens)
        x = x + self.ffn(self.ffn_norm(x))
        return x

    @torch.no_grad()
    def inference(self, x, cu_seqlens, step, kv_cache):
        x = x + self.self_attn.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class ToyLlama(nn.Module):
    def __init__(
        self, config: ToyLlamaConfig, inference_config: Optional[InferenceConfig] = None
    ):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        self.rope_config = RopeConfig(
            head_dim=self.config.head_dim,
            rope_theta=self.config.rope_theta,
            rope_scaling=self.config.rope_scaling,
        )
        self.layers = nn.ModuleList(
            [
                ToyLlamaLayer(
                    hidden_size=self.config.hidden_size,
                    intermediate_size=self.config.intermediate_size,
                    num_q_heads=self.config.num_attention_heads,
                    num_kv_heads=self.config.num_key_value_heads,
                    head_dim=self.config.head_dim,
                    rope_config=RopeConfig(
                        self.config.max_position_embeddings,
                        self.config.head_dim,
                        self.config.rope_theta,
                        self.config.rope_scaling,
                    ),
                )
                for _ in range(self.config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(self.config.hidden_size)
        self.lm_head = nn.Linear(
            self.config.hidden_size, self.config.vocab_size, bias=False
        )

        # inference config and kv cache
        self.inference_config = inference_config
        self.kv_cache = None

    def forward(
        self,
        input_ids: torch.LongTensor,  # shape: [total_length, ]
        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]
    ):
        # embedding
        x = self.embedding(input_ids).to(torch.bfloat16)
        # layers
        for layer in self.layers:
            x = layer(x, cu_seqlens)
        # final norm
        x = self.norm(x)
        # language model head
        x = self.lm_head(x).to(torch.float32)  # [total_len, vocab_size]
        return x

    @torch.no_grad()
    def inference(
        self,
        input_ids: torch.LongTensor,  # prefill shape: [total_length, ]; decode shape: [batch_size, ]
        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]
        step: int,
    ):
        # set kv cache if self.kv_cache is None
        if self.kv_cache is None:
            self.kv_cache = [
                KVCache(
                    max_batch_size=self.inference_config.max_batch_size,
                    max_length=self.inference_config.max_length,
                    num_kv_heads=self.config.num_key_value_heads,
                    head_dim=self.config.head_dim,
                    dtype=torch.bfloat16,
                    device="cuda",
                )
                for _ in range(self.config.num_hidden_layers)
            ]
        # embedding
        x = self.embedding(input_ids).to(torch.bfloat16)
        # layers
        for i, layer in enumerate(self.layers):
            x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])
        # final norm
        x = self.norm(x)
        # language model head; in prefill keep only the last token of each sequence
        if step == 0:
            x = x[cu_seqlens[1:] - 1, :]
        x = self.lm_head(x).to(torch.float32)  # [batch_size, vocab_size]
        return x

    def generate(
        self,
        input_ids: torch.LongTensor,
        cu_seqlens: torch.LongTensor,
        max_new_tokens: int = -1,
    ):
        output_tokens = []
        if max_new_tokens <= 0:
            max_new_tokens = self.inference_config.max_new_tokens
        for step in range(max_new_tokens):
            logits = self.inference(
                input_ids, cu_seqlens, step
            )  # shape: [batch_size, vocab_size]
            next_token = torch.argmax(logits, dim=-1)  # shape: [batch_size, ]
            input_ids = next_token
            output_tokens.append(next_token)
        output_tokens = torch.stack(
            output_tokens, dim=1
        )  # shape: [batch_size, max_new_tokens]
        return output_tokens


if __name__ == "__main__":
    torch.manual_seed(42)
    # initialize model
    config = ToyLlamaConfig(
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=8,
        num_attention_heads=32,
        num_key_value_heads=2,
        head_dim=128,
        rope_theta=500000.0,
        rope_scaling={
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
    )
    inference_config = InferenceConfig(
        max_batch_size=4,
        max_length=8192,
        max_new_tokens=128,
    )
    model = ToyLlama(config, inference_config).cuda().bfloat16()
    print(f"\nMODEL CONFIG:\n{config}\n")
    print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
    print(f"\nMODEL:\n{model}\n")

    # example input
    batch_size = 4
    seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    input_ids = torch.randint(
        0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda"
    )
    print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")

    # example output
    logits = model(input_ids, cu_seqlens)
    print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")

    # example generate
    output_tokens = model.generate(input_ids, cu_seqlens, 64)
    print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")


================================================
FILE: native_sparse_attention/model/toy_nsa_llama.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from native_sparse_attention.module import NativeSparseAttention, RopeConfig, NSACache


@dataclass
class ToyNSALlamaConfig:
    # embedding config
    vocab_size: int = 128288
    max_position_embeddings: int = 131072
    # model config
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 2
    head_dim: int = 128
    # rope config
    rope_theta: float = 500000.0
    rope_scaling: dict = field(
        default_factory=lambda: {
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        }
    )
    # nsa config
    compress_type: str = "weightedpool"
    kernel_size: int = 32
    kernel_stride: int = 16
    block_size: int = 64
    topk: int = 16
    init_blocks: int = 1
    local_blocks: int = 2
    window_size: int = 512


@dataclass
class InferenceConfig:
    max_batch_size: int = 32
    max_length: int = 8192
    max_new_tokens: int = 128


class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class FFN(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class ToyNSALlamaLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        compress_type: str,
        kernel_size: int,
        kernel_stride: int,
        block_size: int,
        topk: int,
        init_blocks: int,
        local_blocks: int,
        window_size: int,
        rope_config: RopeConfig,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.compress_type = compress_type
        self.kernel_size = kernel_size
        self.kernel_stride = kernel_stride
        self.block_size = block_size
        self.topk = topk
        self.init_blocks = init_blocks
        self.local_blocks = local_blocks
        self.window_size = window_size
        self.rope_config = rope_config
        self.attn_norm = RMSNorm(self.hidden_size)
        self.nsa = NativeSparseAttention(
            hidden_size=self.hidden_size,
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            compress_type=self.compress_type,
            kernel_size=self.kernel_size,
            kernel_stride=self.kernel_stride,
            block_size=self.block_size,
            topk=self.topk,
            init_blocks=self.init_blocks,
            local_blocks=self.local_blocks,
            window_size=self.window_size,
            rope_config=rope_config,
        )
        self.ffn_norm = RMSNorm(self.hidden_size)
        self.ffn = FFN(
            hidden_size=self.hidden_size, intermediate_size=self.intermediate_size
        )

    def forward(self, x, cu_seqlens):
        x = x + self.nsa(self.attn_norm(x), cu_seqlens)
        x = x + self.ffn(self.ffn_norm(x))
        return x

    @torch.no_grad()
    def inference(self, x, cu_seqlens, step, kv_cache):
        x = x + self.nsa.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class ToyNSALlama(nn.Module):
    def __init__(
        self,
        config: ToyNSALlamaConfig,
        inference_config: Optional[InferenceConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        self.rope_config = RopeConfig(
            head_dim=self.config.head_dim,
            rope_theta=self.config.rope_theta,
            rope_scaling=self.config.rope_scaling,
        )
        self.layers = nn.ModuleList(
            [
                ToyNSALlamaLayer(
                    hidden_size=self.config.hidden_size,
                    intermediate_size=self.config.intermediate_size,
                    num_q_heads=self.config.num_attention_heads,
                    num_kv_heads=self.config.num_key_value_heads,
                    head_dim=self.config.head_dim,
                    compress_type=self.config.compress_type,
                    kernel_size=self.config.kernel_size,
                    kernel_stride=self.config.kernel_stride,
                    block_size=self.config.block_size,
                    topk=self.config.topk,
                    init_blocks=self.config.init_blocks,
                    local_blocks=self.config.local_blocks,
                    window_size=self.config.window_size,
                    rope_config=RopeConfig(
                        self.config.max_position_embeddings,
                        self.config.head_dim,
                        self.config.rope_theta,
                        self.config.rope_scaling,
                    ),
                )
                for _ in range(self.config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(self.config.hidden_size)
        self.lm_head = nn.Linear(
            self.config.hidden_size, self.config.vocab_size, bias=False
        )

        # inference config and kv cache
        self.inference_config = inference_config
        self.kv_cache = None

    def forward(
        self,
        input_ids: torch.LongTensor,  # shape: [batch_size, max_length]
        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]
    ):
        # embedding
        x = self.embedding(input_ids).to(torch.bfloat16)
        # layers
        for layer in self.layers:
            x = layer(x, cu_seqlens)
        # final norm
        x = self.norm(x)
        # language head
        x = self.lm_head(x).to(torch.float32)  # [total_len, vocab_size]
        return x

    @torch.no_grad()
    def inference(
        self,
        input_ids: torch.LongTensor,  # prefill shape: [total_length, ]; decode shape: [batch_size, ]
        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]
        step: int,
    ):
        # lazily allocate one NSACache per layer on the first call
        if self.kv_cache is None:
            self.kv_cache = [
                NSACache(
                    max_batch_size=self.inference_config.max_batch_size,
                    max_length=self.inference_config.max_length,
                    num_kv_heads=self.config.num_key_value_heads,
                    head_dim=self.config.head_dim,
                    kernel_size=self.config.kernel_size,
                    kernel_stride=self.config.kernel_stride,
                    window_size=self.config.window_size,
                    dtype=torch.bfloat16,
                    device="cuda",
                )
                for _ in range(self.config.num_hidden_layers)
            ]
        # embedding
        x = self.embedding(input_ids).to(torch.bfloat16)
        # layers
        for i, layer in enumerate(self.layers):
            x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])
        # final norm
        x = self.norm(x)
        # language head: at prefill (step 0) keep only the last token of each sequence
        if step == 0:
            x = x[cu_seqlens[1:] - 1, :]
        x = self.lm_head(x).to(torch.float32)  # [batch_size, vocab_size]
        return x

    def generate(
        self,
        input_ids: torch.LongTensor,
        cu_seqlens: torch.LongTensor,
        max_new_tokens: int = -1,
    ):
        output_tokens = []
        if max_new_tokens <= 0:
            max_new_tokens = self.inference_config.max_new_tokens
        for step in range(max_new_tokens):
            logits = self.inference(
                input_ids, cu_seqlens, step
            )  # shape: [batch_size, vocab_size]
            next_token = torch.argmax(logits, dim=-1)  # shape: [batch_size, ]
            input_ids = next_token
            output_tokens.append(next_token)
        output_tokens = torch.stack(
            output_tokens, dim=1
        )  # shape: [batch_size, max_new_tokens]
        return output_tokens


if __name__ == "__main__":
    torch.manual_seed(42)
    # initialize model
    config = ToyNSALlamaConfig(
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=8,
        num_attention_heads=32,
        num_key_value_heads=2,
        head_dim=128,
        rope_theta=500000.0,
        rope_scaling={
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
        compress_type="weightedpool",
        kernel_size=32,
        kernel_stride=16,
        block_size=64,
        topk=8,
        init_blocks=1,
        local_blocks=2,
        window_size=512,
    )
    inference_config = InferenceConfig(
        max_batch_size=4,
        max_length=8192,
        max_new_tokens=128,
    )
    model = ToyNSALlama(config, inference_config).cuda().bfloat16()
    print(f"\nMODEL CONFIG:\n{config}\n")
    print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
    print(f"\nMODEL:\n{model}\n")

    # example input
    batch_size = 4
    # at least one token per sequence, so every sequence has a last token to decode from
    seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    input_ids = torch.randint(
        0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda"
    )
    print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")

    # example output
    logits = model(input_ids, cu_seqlens)
    print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")

    # example generate
    output_tokens = model.generate(input_ids, cu_seqlens, 64)
    print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
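The model works on packed ("varlen") batches: all prompts are concatenated into one `[total_len, ...]` tensor and `cu_seqlens` stores each sequence's start offset, with `cu_seqlens[-1] == total_len`. The CPU-only sketch below, using made-up lengths, shows how `cu_seqlens` is built (as in the example input above) and why `x[cu_seqlens[1:] - 1]` at prefill picks exactly one row per sequence, its last token. Note that a zero-length sequence would make this index point at the previous sequence's last token (or wrap to `-1`), which is why every prompt needs at least one token.

```python
import torch

# Hypothetical per-sequence lengths for a batch of 3 packed sequences.
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)

# cu_seqlens[i] is the offset of sequence i in the packed tensor;
# cu_seqlens[-1] is the total number of tokens.
cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

# Packed hidden states: one row per token, sequences concatenated back to back.
total_len = int(cu_seqlens[-1])
hidden = torch.arange(total_len, dtype=torch.float32).unsqueeze(-1)  # [total_len, 1]

# At prefill only the last token of each sequence predicts the next token,
# so gather rows cu_seqlens[1:] - 1 (one per sequence).
last_idx = (cu_seqlens[1:] - 1).long()
last_hidden = hidden[last_idx]  # [batch_size, 1]

print(cu_seqlens.tolist())  # [0, 3, 8, 10]
print(last_idx.tolist())    # [2, 7, 9]
```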


================================================
FILE: native_sparse_attention/module/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from native_sparse_attention.module.native_sparse_attention import NativeSparseAttention
from native_sparse_attention.module.self_attention import SelfAttention
from native_sparse_attention.module.rope import RotaryEmbedding, RopeConfig
from native_sparse_attention.module.kv_cache import NSACache, KVCache

__all__ = [
    "SelfAttention",
    "NativeSparseAttention",
    "RotaryEmbedding",
    "RopeConfig",
    "NSACache",
    "KVCache",
]


================================================
FILE: native_sparse_attention/module/kv_cache.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import triton.language as tl
from typing import Union
from native_sparse_attention.ops.triton.utils import get_compressed_seqlens


class KVCache:
    def __init__(
        self,
        max_batch_size: int,
        max_length: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
    ):
        self.max_batch_size = max_batch_size
        self.max_length = max_length
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device

        # alloc kv cache tensor for topk sparse attention
        self.kv_cache = torch.zeros(
            2,
            self.max_batch_size,
            self.max_length,
            self.num_kv_heads,
            self.head_dim,
            dtype=self.dtype,
            device=self.device,
        )
        self.kv_len = torch.zeros(
            self.max_batch_size, dtype=torch.int32, device=self.device
        )

    def reset(self):
        self.kv_cache.zero_()
        self.kv_len.zero_()

    def update_kv(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        if step == 0:
            self._update_kv_prefill(
                cu_seqlens,
                step,
                key,
                value,
            )
        else:
            self._update_kv_decode(
                cu_seqlens,
                step,
                key,
                value,
            )

    def _update_kv_prefill(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        assert step == 0
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        batch_size = seqlens.shape[0]
        # sparse part kv, shape check
        total_len, num_heads, head_dim = key.shape
        assert key.shape == value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        assert total_len == cu_seqlens[-1].item()
        # fill sparse part kv cache
        seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:]
        _fill_kv_cache(
            self.kv_cache,
            key,
            value,
            seq_start,
            seq_end,
        )
        self.kv_len[:batch_size] = seqlens

    def _update_kv_decode(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        assert step > 0
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        # sparse part kv, shape check
        batch_size, num_heads, head_dim = key.shape
        assert batch_size == seqlens.shape[0]
        assert key.shape == value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        # fill sparse part kv cache
        brange = torch.arange(batch_size, dtype=torch.int32, device=key.device)
        self.kv_cache[0, :batch_size][brange, self.kv_len[:batch_size]] = key
        self.kv_cache[1, :batch_size][brange, self.kv_len[:batch_size]] = value
        self.kv_len[:batch_size] += 1

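During decoding, `KVCache._update_kv_decode` appends exactly one token per sequence by advanced indexing: row `kv_len[b]` of sequence `b` receives the new key/value, then every length is bumped by one. A minimal CPU sketch of the same pattern, with made-up sizes and lengths (the `.long()` casts are a defensive choice for older PyTorch versions, not something the original code does):

```python
import torch

batch_size, max_length, num_heads, head_dim = 3, 8, 2, 4

# Same layout as KVCache.kv_cache: [2 (key/value), batch, max_length, heads, head_dim]
kv_cache = torch.zeros(2, batch_size, max_length, num_heads, head_dim)
# Hypothetical ragged lengths left behind by prefill.
kv_len = torch.tensor([3, 5, 1], dtype=torch.int32)

# One new token per sequence at this decode step: [batch, heads, head_dim]
new_key = torch.ones(batch_size, num_heads, head_dim)
new_value = 2 * torch.ones(batch_size, num_heads, head_dim)

# Row kv_len[b] of sequence b receives that sequence's new key/value.
brange = torch.arange(batch_size)
kv_cache[0][brange, kv_len.long()] = new_key
kv_cache[1][brange, kv_len.long()] = new_value
kv_len += 1  # every sequence grew by one token
```

Because each sequence writes to a different `(batch, row)` pair, a single indexed assignment handles ragged lengths without a Python loop over the batch.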

class NSACache:
    """KV cache manager for native sparse attention.
    Args:
        max_batch_size (int): max batch size
        max_length (int): max length, including prompt length and response length
        num_kv_heads (int): number of key/value heads
        head_dim (int): head dim
        kernel_size (int): kernel size of compression
        kernel_stride (int): kernel stride for compression
        window_size (int): window size for sliding window attention
        dtype (torch.dtype): data type for kv cache, should be same as model weight dtype
        device (Union[str, torch.device]): default to 'cuda'

    Methods:
        reset: reset kv cache, should be called before prefilling
        prepare_compress: store keys/values for compression, should be called before key/value compression at both prefilling and decoding
        update_kv: update key/value cache, should be called after rope
    """

    def __init__(
        self,
        max_batch_size: int,
        max_length: int,
        num_kv_heads: int,
        head_dim: int,
        kernel_size: int,
        kernel_stride: int,
        window_size: int,
        dtype: torch.dtype,
        device: Union[str, torch.device] = "cuda",
    ):
        self.max_batch_size = max_batch_size
        self.max_length = max_length
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.kernel_size = kernel_size
        self.kernel_stride = kernel_stride
        self.window_size = window_size
        self.dtype = dtype
        self.device = device

        # alloc kv cache tensor for topk sparse attention
        self.sparse_kv_cache = torch.zeros(
            2,
            self.max_batch_size,
            self.max_length,
            self.num_kv_heads,
            self.head_dim,
            dtype=self.dtype,
            device=self.device,
        )
        self.sparse_kv_len = torch.zeros(
            self.max_batch_size, dtype=torch.int32, device=self.device
        )

        # alloc kv cache tensor for compressed attention
        self.max_comp_length = (
            self.max_length - self.kernel_size
        ) // self.kernel_stride + 1
        self.compress_kv_cache = torch.zeros(
            2,
            self.max_batch_size,
            self.max_comp_length,
            self.num_kv_heads,
            self.head_dim,
            dtype=self.dtype,
            device=self.device,
        )
        self.compress_kv_len = torch.zeros(
            self.max_batch_size, dtype=torch.int32, device=self.device
        )
        self.before_compress_kv_cache = torch.zeros(
            2,
            self.max_batch_size,
            self.kernel_size,
            self.num_kv_heads,
            self.head_dim,
            dtype=self.dtype,
            device=self.device,
        )
        self.before_compress_kv_len = torch.zeros(
            self.max_batch_size, dtype=torch.int32, device=self.device
        )

        # alloc kv cache for sliding window attention
        self.sliding_kv_cache = torch.zeros(
            2,
            self.max_batch_size,
            self.window_size,
            self.num_kv_heads,
            self.head_dim,
            dtype=self.dtype,
            device=self.device,
        )
        self.sliding_kv_len = torch.zeros(
            self.max_batch_size, dtype=torch.int32, device=self.device
        )

    def reset(self):
        self.compress_kv_cache.zero_()
        self.compress_kv_len.zero_()
        self.before_compress_kv_cache.zero_()
        self.before_compress_kv_len.zero_()
        self.sparse_kv_cache.zero_()
        self.sparse_kv_len.zero_()
        self.sliding_kv_cache.zero_()
        self.sliding_kv_len.zero_()

    def prepare_compress(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        if step == 0:
            self._prepare_compress_prefill(cu_seqlens, step, key, value)
        else:
            self._prepare_compress_decode(cu_seqlens, step, key, value)

    def _prepare_compress_prefill(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        assert step == 0
        # compress part kv, shape check
        batch_size = cu_seqlens.shape[0] - 1
        total_len, num_heads, head_dim = key.shape
        assert key.shape == value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        comp_seqlens, comp_cu_seqlens = get_compressed_seqlens(
            cu_seqlens, self.kernel_size, self.kernel_stride
        )
        assert total_len == cu_seqlens[-1].item()
        # fill tmp cache
        seq_start = cu_seqlens[:-1] + comp_seqlens * self.kernel_stride
        seq_end = cu_seqlens[1:]
        _fill_kv_cache(
            self.before_compress_kv_cache,
            key,
            value,
            seq_start,
            seq_end,
        )
        self.before_compress_kv_len[:batch_size] = seq_end - seq_start

    def _prepare_compress_decode(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        assert step > 0
        # compress part kv, shape check
        batch_size, num_heads, head_dim = key.shape
        assert key.shape == value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        assert batch_size == cu_seqlens.shape[0] - 1
        # sequences whose before-compress buffer is full were compressed at the
        # previous step: drop the oldest kernel_stride tokens to make room
        idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze()
        self.before_compress_kv_cache[
            :, idx, : self.kernel_size - self.kernel_stride, :, :
        ] = self.before_compress_kv_cache[:, idx, self.kernel_stride :, :, :]
        self.before_compress_kv_len[idx] -= self.kernel_stride
        # fill new kv
        brange = torch.arange(batch_size, dtype=torch.int32, device=key.device)
        self.before_compress_kv_cache[0, :batch_size][
            brange, self.before_compress_kv_len[:batch_size]
        ] = key
        self.before_compress_kv_cache[1, :batch_size][
            brange, self.before_compress_kv_len[:batch_size]
        ] = value
        # update kv len
        self.before_compress_kv_len[:batch_size] += 1

    def update_kv(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        compress_key: torch.Tensor,
        compress_value: torch.Tensor,
        sparse_key: torch.Tensor,
        sparse_value: torch.Tensor,
        sliding_key: torch.Tensor,
        sliding_value: torch.Tensor,
    ):
        if step == 0:
            self._update_kv_prefill(
                cu_seqlens,
                step,
                compress_key,
                compress_value,
                sparse_key,
                sparse_value,
                sliding_key,
                sliding_value,
            )
        else:
            self._update_kv_decode(
                cu_seqlens,
                step,
                compress_key,
                compress_value,
                sparse_key,
                sparse_value,
                sliding_key,
                sliding_value,
            )

    def _update_kv_prefill(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        compress_key: torch.Tensor,
        compress_value: torch.Tensor,
        sparse_key: torch.Tensor,
        sparse_value: torch.Tensor,
        sliding_key: torch.Tensor,
        sliding_value: torch.Tensor,
    ):
        assert step == 0
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        batch_size = seqlens.shape[0]
        # sparse part kv, shape check
        total_len, num_heads, head_dim = sparse_key.shape
        assert sparse_key.shape == sparse_value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        assert total_len == cu_seqlens[-1].item()
        # compress part kv, shape check
        total_comp_len, num_heads, head_dim = compress_key.shape
        assert compress_key.shape == compress_value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        comp_seqlens, comp_cu_seqlens = get_compressed_seqlens(
            cu_seqlens, self.kernel_size, self.kernel_stride
        )
        assert total_comp_len == comp_cu_seqlens[-1].item()
        # sliding window part kv, shape check
        total_len, num_heads, head_dim = sliding_key.shape
        assert sliding_key.shape == sliding_value.shape

        # fill compress part kv cache
        seq_start, seq_end = comp_cu_seqlens[:-1], comp_cu_seqlens[1:]
        _fill_kv_cache(
            self.compress_kv_cache,
            compress_key,
            compress_value,
            seq_start,
            seq_end,
        )
        self.compress_kv_len[:batch_size] = comp_seqlens
        # fill sparse part kv cache
        seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:]
        _fill_kv_cache(
            self.sparse_kv_cache,
            sparse_key,
            sparse_value,
            seq_start,
            seq_end,
        )
        self.sparse_kv_len[:batch_size] = seqlens
        # fill sliding part kv cache
        seq_start = torch.maximum(cu_seqlens[1:] - self.window_size, cu_seqlens[:-1])
        seq_end = cu_seqlens[1:]
        _fill_kv_cache(
            self.sliding_kv_cache,
            sliding_key,
            sliding_value,
            seq_start,
            seq_end,
        )
        self.sliding_kv_len[:batch_size] = seq_end - seq_start

    def _update_kv_decode(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        compress_key: torch.Tensor,
        compress_value: torch.Tensor,
        sparse_key: torch.Tensor,
        sparse_value: torch.Tensor,
        sliding_key: torch.Tensor,
        sliding_value: torch.Tensor,
    ):
        assert step > 0
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        # sparse part kv, shape check
        batch_size, num_heads, head_dim = sparse_key.shape
        assert batch_size == seqlens.shape[0]
        assert sparse_key.shape == sparse_value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        # compress part kv, shape check
        batch_size, num_heads, head_dim = compress_key.shape
        assert batch_size == seqlens.shape[0]
        assert compress_key.shape == compress_value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        # sliding window part kv, shape check
        total_len, num_heads, head_dim = sliding_key.shape
        assert sliding_key.shape == sliding_value.shape

        # fill compressed kv cache; only sequences whose buffer now holds a full
        # kernel_size window emit a new compressed token
        idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze()
        self.compress_kv_cache[0][idx, self.compress_kv_len[idx]] = compress_key[idx]
        self.compress_kv_cache[1][idx, self.compress_kv_len[idx]] = compress_value[idx]
        self.compress_kv_len[idx] += 1
        # fill sparse part kv cache
        brange = torch.arange(batch_size, dtype=torch.int32, device=sparse_key.device)
        self.sparse_kv_cache[0, :batch_size][
            brange, self.sparse_kv_len[:batch_size]
        ] = sparse_key
        self.sparse_kv_cache[1, :batch_size][
            brange, self.sparse_kv_len[:batch_size]
        ] = sparse_value
        self.sparse_kv_len[:batch_size] += 1
        # fill sliding window kv cache (ring buffer indexed by sliding_kv_len % window_size)
        self.sliding_kv_cache[0, :batch_size][
            brange, self.sliding_kv_len[:batch_size] % self.window_size
        ] = sliding_key
        self.sliding_kv_cache[1, :batch_size][
            brange, self.sliding_kv_len[:batch_size] % self.window_size
        ] = sliding_value
        self.sliding_kv_len[:batch_size] += 1
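The interplay of `before_compress_kv_len` and `compress_kv_len` across prefill and decode is easy to get wrong, so here is a pure-Python simulation of just those two counters for one sequence. It assumes `get_compressed_seqlens` counts windows of `kernel_size` tokens starting every `kernel_stride` tokens, giving zero windows for sequences shorter than one kernel (the helper `num_compressed` is hypothetical). Under that assumption, the number of compressed tokens after prefill plus any number of decode steps should match what compressing the whole sequence at once would give, which is also the formula behind `max_comp_length`.

```python
def num_compressed(seq_len: int, kernel_size: int, kernel_stride: int) -> int:
    # Assumed to match get_compressed_seqlens: one window per kernel_stride tokens,
    # each kernel_size long; too-short sequences produce no compressed tokens.
    return max(0, (seq_len - kernel_size) // kernel_stride + 1)


def simulate(prefill_len: int, decode_steps: int, kernel_size: int, kernel_stride: int):
    """Track NSACache's counters for one sequence.

    buf_len mirrors before_compress_kv_len, comp_len mirrors compress_kv_len.
    """
    comp_len = num_compressed(prefill_len, kernel_size, kernel_stride)
    # Prefill keeps the tail not yet covered by the next full window
    # (seq_start = start + comp_seqlens * kernel_stride in _prepare_compress_prefill).
    buf_len = prefill_len - comp_len * kernel_stride
    for _ in range(decode_steps):
        # _prepare_compress_decode: a full buffer was compressed on the previous
        # step, so shift out its oldest kernel_stride tokens before appending.
        if buf_len == kernel_size:
            buf_len -= kernel_stride
        buf_len += 1
        # _update_kv_decode: emit one compressed token when the buffer is full.
        if buf_len == kernel_size:
            comp_len += 1
    return buf_len, comp_len
```

For example, `simulate(40, 8, 32, 16)` ends with two compressed tokens, the same as `num_compressed(48, 32, 16)`.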

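The sliding-window cache is a ring buffer: prefill copies the last `window_size` tokens into slots `0..len-1`, and each decode step writes to slot `sliding_kv_len % window_size`, overwriting the oldest entry once the window is full. The sketch below stores absolute token positions instead of key/value vectors to check that the buffer always holds exactly the most recent `min(total_len, window_size)` tokens. Slot order is scrambled after wrap-around; presumably this is harmless because keys are cached after RoPE and attention sums over keys regardless of order.

```python
import torch


def sliding_window_contents(prefill_len: int, decode_steps: int, window_size: int):
    # One slot per cached token; store absolute positions instead of key/value vectors.
    cache = torch.full((window_size,), -1, dtype=torch.long)
    # Prefill keeps only the last window_size tokens, written in order from slot 0
    # (mirrors seq_start = max(end - window_size, start) in _update_kv_prefill).
    start = max(prefill_len - window_size, 0)
    n = prefill_len - start
    cache[:n] = torch.arange(start, prefill_len)
    # Decode writes token p into slot n % window_size; n keeps growing like sliding_kv_len.
    for p in range(prefill_len, prefill_len + decode_steps):
        cache[n % window_size] = p
        n += 1
    valid = min(n, window_size)
    return cache[:valid], n
```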

@triton.jit
def _fill_kv_cache_kernel(
    cache_ptr,
    k_ptr,
    v_ptr,
    seq_start,
    seq_end,
    head_dim,
    stride_c2,
    stride_cb,
    stride_cn,
    stride_ch,
    stride_cd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # program ids: (key/value selector, batch id), head id, block of rows
    pid_2b = tl.program_id(0)
    pid_2 = pid_2b % 2
    pid_b = pid_2b // 2
    pid_h = tl.program_id(1)
    pid_n = tl.program_id(2)
    # start offset and length of this sequence in the packed (padding-free) key/value
    kv_start = tl.load(seq_start + pid_b)
    kv_end = tl.load(seq_end + pid_b)
    kv_len = kv_end - kv_start
    if pid_n * BLOCK_SIZE_N >= kv_len:
        return
    # load key or value
    if pid_2 == 0:
        kv_ptrs = tl.make_block_ptr(
            base=k_ptr + kv_start * stride_kn + pid_h * stride_kh,
            shape=(kv_len, head_dim),
            strides=(stride_kn, stride_kd),
            offsets=(pid_n * BLOCK_SIZE_N, 0),
            block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
            order=(1, 0),
        )
        kv = tl.load(kv_ptrs, boundary_check=(0, 1))
    else:
        kv_ptrs = tl.make_block_ptr(
            base=v_ptr + kv_start * stride_vn + pid_h * stride_vh,
            shape=(kv_len, head_dim),
            strides=(stride_vn, stride_vd),
            offsets=(pid_n * BLOCK_SIZE_N, 0),
            block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
            order=(1, 0),
        )
        kv = tl.load(kv_ptrs, boundary_check=(0, 1))
    # store to cache
    cache_ptrs = tl.make_block_ptr(
        base=cache_ptr + pid_2 * stride_c2 + pid_b * stride_cb + pid_h * stride_ch,
        shape=(kv_len, head_dim),
        strides=(stride_cn, stride_cd),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(cache_ptrs, kv.to(cache_ptr.dtype.element_ty), boundary_check=(0, 1))


def _fill_kv_cache(
    kv_cache: torch.Tensor,  # shape: [2, b, n, h, d]
    key: torch.Tensor,  # shape: [l, h, d]
    value: torch.Tensor,  # shape: [l, h, d]
    seq_start: torch.Tensor,  # shape: [b]
    seq_end: torch.Tensor,  # shape: [b]
):
    total_len, num_heads, head_dim = key.shape
    batch_size = seq_start.shape[0]
    max_kv_len = (seq_end - seq_start).max().item()
    # no kv cache to fill
    if max_kv_len == 0:
        return
    BLOCK_SIZE_N = min(1024, triton.next_power_of_2(max_kv_len))
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    grid = (2 * batch_size, num_heads, triton.cdiv(max_kv_len, BLOCK_SIZE_N))
    _fill_kv_cache_kernel[grid](
        kv_cache,
        key,
        value,
        seq_start,
        seq_end,
        head_dim,
        kv_cache.stride(0),
        kv_cache.stride(1),
        kv_cache.stride(2),
        kv_cache.stride(3),
        kv_cache.stride(4),
        key.stride(0),
        key.stride(1),
        key.stride(2),
        value.stride(0),
        value.stride(1),
        value.stride(2),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
    )
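For testing the Triton kernel, a loop-based reference helps. Reading `_fill_kv_cache_kernel`, each program copies rows `seq_start[b]:seq_end[b]` of `key` (or `value`) for one head into rows `0:len` of `kv_cache[0 or 1, b, :, head]`, leaving later rows untouched. The hypothetical `fill_kv_cache_reference` below does the same with plain slicing. One observation from the kernel: its store boundary check uses `(kv_len, head_dim)` rather than the cache capacity, so callers must ensure `seq_end - seq_start` fits in the cache's length dimension; the reference raises a shape error in that case instead.

```python
import torch


def fill_kv_cache_reference(kv_cache, key, value, seq_start, seq_end):
    """Slicing-based equivalent of _fill_kv_cache (hypothetical test helper).

    kv_cache: [2, b, n, h, d]; key/value: [l, h, d]; seq_start/seq_end: [b]
    Copies key[seq_start[i]:seq_end[i]] into kv_cache[0, i, :len] and value into
    kv_cache[1, i, :len]; rows at or beyond len are left unchanged.
    """
    for i, (s, e) in enumerate(zip(seq_start.tolist(), seq_end.tolist())):
        length = e - s
        kv_cache[0, i, :length] = key[s:e].to(kv_cache.dtype)
        kv_cache[1, i, :length] = value[s:e].to(kv_cache.dtype)
    return kv_cache


# Usage: three packed sequences of lengths 3, 5 and 2 (10 tokens in total).
torch.manual_seed(0)
key = torch.randn(10, 2, 4)
value = torch.randn(10, 2, 4)
cache = torch.zeros(2, 3, 6, 2, 4)
seq_start = torch.tensor([0, 3, 8])
seq_end = torch.tensor([3, 8, 10])
fill_kv_cache_reference(cache, key, value, seq_start, seq_end)
```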


================================================
FILE: native_sparse_attention/module/native_sparse_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flash_attn import flash_attn_varlen_func
from native_sparse_attention.ops import (
    compressed_attention,
    topk_sparse_attention,
    avgpool_compress,
    weightedpool_compress,
    linear_compress,
)
from einops import rearrange
from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding
from native_sparse_attention.infer import nsa_infer
from native_sparse_attention.module.kv_cache import NSACache

COMPRESS_TYPE_TO_FUNC = {
    "avgpool": avgpool_compress,
    "weightedpool": weightedpool_compress,
    "linear": linear_compress,
}

COMPRESS_TYPE_TO_WEIGHT = {
    "avgpool": lambda num_heads, head_dim, kernel_size: None,
    "weightedpool": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter(
        torch.zeros(num_heads, kernel_size)
    ),
    "linear": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter(
        torch.zeros(num_heads, head_dim * kernel_size, head_dim)
    ),
}
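The compression kernels behind `avgpool_compress`, `weightedpool_compress` and `linear_compress` are not shown in this section. As a rough guide to the weight shapes above, here is a hypothetical pure-PyTorch reading for a single unpadded sequence: keys are cut into windows of `kernel_size` tokens every `kernel_stride` tokens, then averaged, weighted per head and position (`[num_heads, kernel_size]`), or flattened and projected per head (`[num_heads, head_dim * kernel_size, head_dim]`). The flattening order, the lack of any normalisation on the pooling weights, and the omission of `intra_block_pe` are all assumptions; this is not the library's kernel.

```python
import torch


def hypothetical_compress(k, weight, kernel_size, kernel_stride, mode):
    # k: [seq_len, num_heads, head_dim] for one sequence; requires seq_len >= kernel_size.
    # Overlapping windows: [num_windows, num_heads, head_dim, kernel_size],
    # with num_windows = (seq_len - kernel_size) // kernel_stride + 1.
    windows = k.unfold(0, kernel_size, kernel_stride)
    if mode == "avgpool":
        return windows.mean(dim=-1)  # [n, h, d]
    if mode == "weightedpool":
        # weight: [num_heads, kernel_size], one scalar per window position per head
        return torch.einsum("nhdk,hk->nhd", windows, weight)
    if mode == "linear":
        # weight: [num_heads, head_dim * kernel_size, head_dim]
        n, h, d, ks = windows.shape
        # Assumed position-major flattening: [n, h, kernel_size * head_dim]
        flat = windows.permute(0, 1, 3, 2).reshape(n, h, ks * d)
        return torch.einsum("nhi,hio->nho", flat, weight)
    raise ValueError(f"unknown mode: {mode}")


# Usage: 40 tokens, 2 kv heads, head_dim 4, kernel 8, stride 4 -> 9 compressed tokens.
torch.manual_seed(0)
k = torch.randn(40, 2, 4)
avg = hypothetical_compress(k, None, 8, 4, "avgpool")
wp = hypothetical_compress(k, torch.full((2, 8), 1.0 / 8), 8, 4, "weightedpool")
lin = hypothetical_compress(k, torch.randn(2, 8 * 4, 4), 8, 4, "linear")
```

With uniform weights of `1 / kernel_size`, the weighted pool reduces to the average pool, which is a quick sanity check on the window layout.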


class NativeSparseAttention(torch.nn.Module):
    """Native sparse attention module, supports training and inference

    Args:
        compress_type (str): key/value compression type, currently supports ['linear', 'avgpool', 'weightedpool']
        hidden_size (int): hidden dimension
        num_q_heads (int): number of query heads
        num_kv_heads (int): number of key/value heads, num_q_heads must be divisible by num_kv_heads
        head_dim (int): head dim
        kernel_size (int): kernel size of compression
        kernel_stride (int): kernel stride for compression
        block_size (int): block size of sparse attention
        topk (int): topk of sparse attention
        init_blocks (int): number of blocks at the beginning of the sequence that are always computed in sparse attention
        local_blocks (int): number of blocks in the local window of each query that are always computed in sparse attention
        window_size (int): window size for sliding window attention
        rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
        rope_device (str): device used to store rope freqs
    """

    def __init__(
        self,
        compress_type: str,
        hidden_size: int,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        kernel_size: int,
        kernel_stride: int,
        block_size: int,
        topk: int,
        init_blocks: int,
        local_blocks: int,
        window_size: int,
        rope_config: RopeConfig,
        rope_device: str = "cuda",
    ):
        super().__init__()
        # configs
        self.compress_type = compress_type
        self.hidden_size = hidden_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.kernel_size = kernel_size
        self.kernel_stride = kernel_stride
        self.block_size = block_size
        self.topk = topk
        self.init_blocks = init_blocks
        self.local_blocks = local_blocks
        self.window_size = window_size
        self.rope_config = rope_config
        assert self.head_dim == self.rope_config.head_dim

        # qkv proj and o proj
        self.proj_q = torch.nn.Linear(
            self.hidden_size, self.num_q_heads * self.head_dim, bias=False
        )
        self.proj_k = torch.nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.proj_v = torch.nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.proj_o = torch.nn.Linear(
            self.num_q_heads * self.head_dim, self.hidden_size, bias=False
        )

        # nsa compress func
        self.compress_func = COMPRESS_TYPE_TO_FUNC[self.compress_type]

        # nsa parameters
        self.compress_key = COMPRESS_TYPE_TO_WEIGHT[self.compress_type](
            num_kv_heads, head_dim, kernel_size
        )

        self.compress_value = COMPRESS_TYPE_TO_WEIGHT[self.compress_type](
            num_kv_heads, head_dim, kernel_size
        )
        self.intra_block_pe = torch.nn.Parameter(
            torch.zeros(self.num_kv_heads, self.kernel_size, self.head_dim)
        )

        # gate function
        self.gate = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.num_q_heads * 3, bias=False),
            torch.nn.Sigmoid(),
        )

        # rope
        self.rope = RotaryEmbedding(self.rope_config, device=rope_device)

        # init parameters
        self.init_params()

    def init_params(self):
        for p in self.parameters():
            if len(p.shape) > 1:
                torch.nn.init.xavier_uniform_(p)

    def forward(
        self,
        x: torch.Tensor,  # shape: [total_len, hidden_size]
        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]
    ):
        # dtype and shape check
        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
        assert x.shape[-1] == self.hidden_size
        cu_seqlens = cu_seqlens.to(torch.int32)
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

        # qkv proj
        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)

        # compressed key and value before rope
        compressed_k, compressed_cu_seqlens = self.compress_func(
            k,
            self.compress_key,
            cu_seqlens,
            self.kernel_size,
            self.kernel_stride,
            self.intra_block_pe,
        )
        compressed_v, _ = self.compress_func(
            v,
            self.compress_value,
            cu_seqlens,
            self.kernel_size,
            self.kernel_stride,
            None,
        )

        # do rope for query and compressed key
        q = self.rope(q, cu_seqlens)
        compressed_k = self.rope(
            compressed_k, compressed_cu_seqlens, stride=self.kernel_stride
        )

        # attention between query and compressed key value
        compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
        compressed_attn_output, topk_idx = compressed_attention(
            q,
            compressed_k,
            compressed_v,
            self.kernel_size,
            self.kernel_stride,
            self.block_size,
            self.topk,
            cu_seqlens,
            compressed_cu_seqlens,
            seqlens.max().item(),
            compressed_seqlens.max().item(),
            None,
            self.init_blocks,
            self.local_blocks,
        )

        # do rope for original key
        k = self.rope(k, cu_seqlens)

        # topk sparse attention
        sparse_attn_output = topk_sparse_attention(
            q, k, v, topk_idx, self.block_size, cu_seqlens, None
        )

        # sliding window attention
        sliding_attn_output = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            seqlens.max().item(),
            seqlens.max().item(),
            causal=True,
            window_size=(self.window_size, -1),
        )

        # gated combination of the compressed, selected and sliding-window branches
        gate = self.gate(x)
        gate = rearrange(gate, "n (h g) -> n h g", g=3)
        attn_output = (
            gate[..., 0:1] * compressed_attn_output
            + gate[..., 1:2] * sparse_attn_output
            + gate[..., 2:3] * sliding_attn_output
        )

        # rearrange and output proj
        attn_output = rearrange(attn_output, "n h d -> n (h d)")
        attn_output = self.proj_o(attn_output)

        return attn_output

    @torch.no_grad()
    def inference(
        self,
        x: torch.Tensor,  # shape: [total_len, hidden_size]
        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]
        step: int,
        cache: NSACache,
    ):
        # dtype and shape check
        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
        assert x.shape[-1] == self.hidden_size
        cu_seqlens = cu_seqlens.to(torch.int32)
        assert step >= 0
        if step == 0:
            assert x.shape[0] == cu_seqlens[-1]
        else:
            assert x.shape[0] == cu_seqlens.shape[0] - 1
        # qkv proj
        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
        # gate proj
        gate = self.gate(x)
        gate = rearrange(gate, "n (h g) -> n h g", g=3)
        # nsa infer
        output = nsa_infer(
            cu_seqlens,
            step,
            q,
            k,
            v,
            gate,
            self.rope,
            cache,
            [self.compress_key, self.compress_value],
            [self.compress_func, self.compress_func],
            self.intra_block_pe,
            self.kernel_size,
            self.kernel_stride,
            self.block_size,
            self.topk,
            self.init_blocks,
            self.local_blocks,
            self.window_size,
        )
        # output proj
        output = rearrange(output, "n h d -> n (h d)")
        output = self.proj_o(output)
        return output
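In `forward` above, `self.gate` (a `Linear` followed by `Sigmoid`) produces `num_q_heads * 3` values per token. `rearrange(gate, "n (h g) -> n h g", g=3)` groups them so the three gates of each head are contiguous, and each branch output is scaled by its own gate before the three are summed. The pure-Python sketch below restates that combination for a single token. It is illustrative only: `gated_combine` is a hypothetical helper, not part of the repository, and it takes gate values that have already passed through the sigmoid, as in the module.

```python
from typing import List

Vec = List[float]


def gated_combine(
    gates: List[float],     # one token: num_heads * 3 sigmoid outputs, layout "(h g)"
    compressed: List[Vec],  # [num_heads][head_dim] compressed-attention output
    selected: List[Vec],    # [num_heads][head_dim] top-k sparse-attention output
    sliding: List[Vec],     # [num_heads][head_dim] sliding-window output
) -> List[Vec]:
    num_heads = len(compressed)
    assert len(gates) == num_heads * 3, "expected three gates per head"
    out: List[Vec] = []
    for h in range(num_heads):
        # "n (h g) -> n h g" with g=3: head h owns gates[3h], gates[3h+1], gates[3h+2]
        g_cmp, g_slc, g_win = gates[3 * h], gates[3 * h + 1], gates[3 * h + 2]
        out.append(
            [g_cmp * c + g_slc * s + g_win * w
             for c, s, w in zip(compressed[h], selected[h], sliding[h])]
        )
    return out
```

Because each gate is an independent sigmoid rather than a softmax, the three weights for a head need not sum to one.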


================================================
FILE: native_sparse_attention/module/rope.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from dataclasses import dataclass, field
from torch import nn
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS


# defaults follow the Llama 3.1 RoPE config
@dataclass
class RopeConfig:
    """Config for RotaryEmbedding, similar to transformers llama."""

    max_position_embeddings: int = 131072
    head_dim: int = 128
    rope_theta: float = 500000
    rope_scaling: dict = field(
        default_factory=lambda: {
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        }
    )
    # unused; kept only for compatibility with transformers' RoPE init functions (use head_dim instead)
    hidden_size: int = 1
    num_attention_heads: int = 1

    def __post_init__(self):
        self.num_attention_heads = 1
        self.hidden_size = self.head_dim


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# copied and modified from Hugging Face transformers
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class RotaryEmbedding(nn.Module):
    """Rotary embedding

    Args:
        config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
        device (torch.device or str): device for the inv_freq buffer. Defaults to the current CUDA device.
    """

    cos = None
    sin = None

    def __init__(
        self, config: RopeConfig, device=torch.device(torch.cuda.current_device())
    ):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def generate_cos_sin(self, x: torch.Tensor, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
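            # do not duplicate freqs here (flash_attn-style rotary expects half-dim cos/sin);
            # the full-dim cos/sin are built with torch.cat when caching them
            # emb = torch.cat((freqs, freqs), dim=-1)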
            emb = freqs
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = (cos * self.attention_scaling).to(dtype=x.dtype).squeeze(0)
        sin = (sin * self.attention_scaling).to(dtype=x.dtype).squeeze(0)

        # cache full-dim cos/sin at class level so every RotaryEmbedding instance reuses them
        RotaryEmbedding.cos = torch.cat([cos, cos], dim=-1)
        RotaryEmbedding.sin = torch.cat([sin, sin], dim=-1)

        return RotaryEmbedding.cos, RotaryEmbedding.sin

    @torch.no_grad()
    def generate_pos_embs(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        seqlens: torch.Tensor,
        step: int = 0,
        stride: int = 1,
    ):
        if (
            RotaryEmbedding.cos is None
            or seqlens.max() + step > RotaryEmbedding.cos.shape[0]
        ):
            self.generate_cos_sin(
                x, torch.arange(seqlens.max() + step).to(x.device)[None, :]
            )

        cos_embs = []
        sin_embs = []
        bsz = len(cu_seqlens) - 1

        for i in range(bsz):
            if step == 0:  # prefilling
                r = cu_seqlens[i + 1] - cu_seqlens[i]
                cos_emb, sin_emb = (
                    RotaryEmbedding.cos[: r * stride : stride],
                    RotaryEmbedding.sin[: r * stride : stride],
                )
            elif step > 0:  # decoding
                r = cu_seqlens[i + 1] - cu_seqlens[i] + step - 1
                cos_emb, sin_emb = (
                    RotaryEmbedding.cos[r * stride : r * stride + 1],
                    RotaryEmbedding.sin[r * stride : r * stride + 1],
                )
            cos_embs.append(cos_emb)
            sin_embs.append(sin_emb)

        cos_embs = torch.cat(cos_embs, dim=0)
        sin_embs = torch.cat(sin_embs, dim=0)
        return cos_embs, sin_embs

    def forward(self, x, cu_seqlens, step=0, stride=1):
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        cos_embs, sin_embs = self.generate_pos_embs(
            x,
            cu_seqlens,
            seqlens,
            step=step,
            stride=stride,
        )
        N, H, D = x.shape[0], x.shape[-2], x.shape[-1]  # H: number of heads
        x = x * cos_embs.view(N, 1, D) + rotate_half(x) * sin_embs.view(N, 1, D)
        return x


================================================
FILE: native_sparse_attention/module/self_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flash_attn import flash_attn_varlen_func
from einops import rearrange
from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding
from native_sparse_attention.module.kv_cache import KVCache
from native_sparse_attention.ops import flash_attention_decode


class SelfAttention(torch.nn.Module):
    """self attention module

    Args:
        hidden_size (int): hidden dimension
        num_q_heads (int): number of query heads
        num_kv_heads (int): number of key/value heads; num_q_heads must be divisible by num_kv_heads
        head_dim (int): head dim
        rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
        rope_device (str): device for the rotary embedding buffers. Defaults to "cuda".
    """

    def __init__(
        self,
        hidden_size: int,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rope_config: RopeConfig,
        rope_device: str = "cuda",
    ):
        super().__init__()
        # configs
        self.hidden_size = hidden_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.rope_config = rope_config
        assert self.head_dim == self.rope_config.head_dim

        # qkv proj and o proj
        self.proj_q = torch.nn.Linear(
            self.hidden_size, self.num_q_heads * self.head_dim, bias=False
        )
        self.proj_k = torch.nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.proj_v = torch.nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.proj_o = torch.nn.Linear(
            self.num_q_heads * self.head_dim, self.hidden_size, bias=False
        )
        # rope
        self.rope = RotaryEmbedding(self.rope_config, device=rope_device)

        # init parameters
        self.init_params()

    def init_params(self):
        for p in self.parameters():
            torch.nn.init.xavier_uniform_(p)

    def forward(
        self,
        x: torch.Tensor,  # shape: [total_len, hidden_size]
        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]
    ):
        # dtype and shape check
        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
        assert x.shape[-1] == self.hidden_size
        cu_seqlens = cu_seqlens.to(torch.int32)
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

        # qkv proj
        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)

        # apply rope to query and key
        q = self.rope(q, cu_seqlens)
        k = self.rope(k, cu_seqlens)

        # self attention
        attn_output = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            seqlens.max().item(),
            seqlens.max().item(),
            causal=True,
        )

        # rearrange and output proj
        attn_output = rearrange(attn_output, "n h d -> n (h d)")
        attn_output = self.proj_o(attn_output)

        return attn_output

    @torch.no_grad()
    def inference(
        self,
        x: torch.Tensor,  # shape: [total_len, hidden_size]
        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]
        step: int,
        cache: KVCache,
    ):
        # dtype and shape check
        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
        assert x.shape[-1] == self.hidden_size
        cu_seqlens = cu_seqlens.to(torch.int32)
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        assert step >= 0
        if step == 0:
            assert x.shape[0] == cu_seqlens[-1]
        else:
            assert x.shape[0] == cu_seqlens.shape[0] - 1
        batch_size = cu_seqlens.shape[0] - 1
        # qkv proj
        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
        # apply rope to query and key
        q = self.rope(q, cu_seqlens, step)
        k = self.rope(k, cu_seqlens, step)
        # reset and update kv cache
        if step == 0:
            cache.reset()
        cache.update_kv(cu_seqlens, step, k, v)
        # self attention
        if step == 0:
            cu_seqlens_q = cu_seqlens_k = cu_seqlens
            max_seqlen_in_batch_q = max_seqlen_in_batch_k = seqlens.max().item()
            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                causal=True,
            )
        else:
            output = flash_attention_decode(
                q,
                cache.kv_cache[0, :batch_size],
                cache.kv_cache[1, :batch_size],
                cache.kv_len[:batch_size],
            )
        # rearrange and output proj
        output = rearrange(output, "n h d -> n (h d)")
        output = self.proj_o(output)
        return output
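In `inference`, the decode branch (`step > 0`) first appends the new token's key/value to the cache via `cache.update_kv`, then lets the single query attend over the first `kv_len` cached entries, so no causal mask is needed. The reference sketch below restates that computation in plain Python for one sequence. It assumes the common grouped-query convention used by `flash_attn` (query head `h` reads kv head `h // (num_q_heads // num_kv_heads)`) and a simple `[max_len][num_kv_heads][head_dim]` cache layout; whether `flash_attention_decode` stores its cache the same way is not shown here, and `decode_attention` is a hypothetical name.

```python
import math
from typing import List

Vec = List[float]


def decode_attention(
    q: List[Vec],              # [num_q_heads][head_dim], the new token's query
    k_cache: List[List[Vec]],  # [max_len][num_kv_heads][head_dim]
    v_cache: List[List[Vec]],  # [max_len][num_kv_heads][head_dim]
    kv_len: int,               # number of valid cached tokens (includes the new one)
) -> List[Vec]:
    num_q_heads, head_dim = len(q), len(q[0])
    num_kv_heads = len(k_cache[0])
    assert num_q_heads % num_kv_heads == 0
    group = num_q_heads // num_kv_heads
    scale = 1.0 / math.sqrt(head_dim)
    out: List[Vec] = []
    for h in range(num_q_heads):
        kv_h = h // group  # consecutive query heads share one kv head
        scores = [scale * sum(a * b for a, b in zip(q[h], k_cache[t][kv_h]))
                  for t in range(kv_len)]  # entries past kv_len are ignored
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v_cache[t][kv_h][d] for t, w in enumerate(weights)) / z
                    for d in range(head_dim)])
    return out
```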


================================================
FILE: native_sparse_attention/ops/README.md
================================================
# Triton Functions for Native Sparse Attention

This folder provides efficient Triton-based implementations of components for Native Sparse Attention. This README introduces the available functions, explains how to set them up, and offers guidance on their usage.

---

## Overview of Functions

The functions are organized into two main categories:

1. **Compression Methods**: Techniques for compressing key and value tensors.
2. **Attention Mechanisms**: Methods for computing attention between queries and compressed key/value tensors, including top-k sparse attention.

---

## Function Descriptions

### Compression Methods

These functions compress key and value tensors using a sliding window approach. Within each window, `kernel_size` tokens are compressed into a single token, and consecutive windows start `kernel_stride` tokens apart. All compression functions share similar input parameters and return formats. In the formulas below, $m$ denotes `kernel_size` and $k_1, \dots, k_m$ are the tokens of one window.

**Parameters:**
- `x`: Input tensor (`total_len, num_heads, head_dim`)
- `w`: Weight tensor (shape varies by compression method)
- `cu_seqlens`: Cumulative sequence lengths (`batch_size + 1`)
- `kernel_size`: Size of the compression window
- `kernel_stride`: Stride of the compression window
- `pe`: Optional positional embedding (`num_heads, kernel_size, head_dim`)

**Returns:**
- Compressed tensor (`total_compress_len, num_heads, head_dim`)
- Cumulative sequence lengths (`com_cu_seqlens`) for the compressed tensor
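The number of compressed tokens per sequence follows from the window arithmetic in the PyTorch reference implementations (`ops/torch/compress_key_value.py`): a sequence of length `n >= kernel_size` yields `(n - kernel_size) // kernel_stride + 1` tokens, window `i` covering positions `[i * kernel_stride, i * kernel_stride + kernel_size)`, and a shorter sequence yields none. A plain-Python sketch of how `com_cu_seqlens` is derived from `cu_seqlens` (the helper name is hypothetical):

```python
from itertools import accumulate
from typing import List, Tuple


def compressed_seqlens(
    cu_seqlens: List[int], kernel_size: int, kernel_stride: int
) -> Tuple[List[int], List[int]]:
    seqlens = [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    # one compressed token per complete window; sequences shorter than a window get none
    y_seqlens = [
        (n - kernel_size) // kernel_stride + 1 if n >= kernel_size else 0
        for n in seqlens
    ]
    y_cu_seqlens = [0] + list(accumulate(y_seqlens))
    return y_seqlens, y_cu_seqlens
```

For the `cu_seqlens = [0, 1024, 8192, 16384]`, `kernel_size = 32`, `kernel_stride = 16` setup used in the usage example, this gives 63, 447 and 511 compressed tokens per sequence.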

#### `weightedpool_compress`
Compresses the input tensor using weighted pooling, applying a weighted sum over each block:  
$\hat{k} = w_1 k_1 + \dots + w_m k_m$  
- **Weight shape**: `(num_heads, kernel_size)`

#### `avgpool_compress`
Compresses the input tensor using average pooling:  
$\hat{k} = (k_1 + \dots + k_m) / m$  
- **Weight**: Must be `None`

#### `linear_compress`
Compresses the input tensor via linear projection, mapping each block to a single vector using learned weights:  
$\hat{k} = \text{cat}(k_1, \dots, k_m) W$  
- **Weight shape**: `(num_heads, kernel_size * head_dim, head_dim)`
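For a single head and a single window, the three formulas above reduce to a few lines of plain Python. This sketch only illustrates the math (the Triton kernels process all heads and windows at once), and the function names are hypothetical:

```python
from typing import List

Vec = List[float]


def avgpool_window(window: List[Vec]) -> Vec:
    # (k_1 + ... + k_m) / m, computed per feature
    m = len(window)
    return [sum(col) / m for col in zip(*window)]


def weightedpool_window(window: List[Vec], w: List[float]) -> Vec:
    # w_1 k_1 + ... + w_m k_m; w holds this head's kernel_size weights
    return [sum(wi * x for wi, x in zip(w, col)) for col in zip(*window)]


def linear_window(window: List[Vec], W: List[Vec]) -> Vec:
    # cat(k_1, ..., k_m) W, where W has kernel_size * head_dim rows and head_dim columns
    flat = [x for token in window for x in token]
    return [sum(f * W[r][c] for r, f in enumerate(flat)) for c in range(len(W[0]))]
```

All three are linear in the window, so adding an intra-block positional embedding `pe` to every window is equivalent to adding the pooled `pe` once as a bias. That is how the PyTorch reference versions of `avgpool_compress` and `weightedpool_compress` apply it. Average pooling is also the special case of the other two with uniform weights `1 / m`.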

---

### Attention Mechanisms

These functions compute attention using either full or sparse mechanisms.

#### `flash_attention_varlen`
A variable-length implementation of flash attention, similar to `flash_attn_varlen_func` from the `flash_attn` package.

**Parameters:**
- `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)
- `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for queries and keys
- `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths in the batch
- `causal`: Apply causal masking (default: `False`)
- `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)

**Returns:**
- Attention output tensor (`total_q_len, num_q_heads, head_dim`)
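As a correctness reference (not how the Triton kernel is structured), the semantics of `flash_attention_varlen` for self-attention can be written naively: the packed batch is split by `cu_seqlens`, each token attends only to keys of its own sequence, and with `causal=True` only to positions at or before itself. The sketch handles one head and assumes `cu_seqlens_q == cu_seqlens_k`; `varlen_attention_ref` is a hypothetical name.

```python
import math
from typing import List

Vec = List[float]


def varlen_attention_ref(
    q: List[Vec], k: List[Vec], v: List[Vec], cu_seqlens: List[int], causal: bool = True
) -> List[Vec]:
    # single head; q, k, v are packed as [total_len][head_dim] and share cu_seqlens
    scale = 1.0 / math.sqrt(len(q[0]))
    out: List[Vec] = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        for i in range(start, end):
            last = i if causal else end - 1  # never look past this sequence's end
            scores = [scale * sum(a * b for a, b in zip(q[i], k[j]))
                      for j in range(start, last + 1)]
            m = max(scores)  # stable softmax
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            out.append([sum(wj * v[start + j][d] for j, wj in enumerate(w)) / z
                        for d in range(len(v[0]))])
    return out
```

Packing sequences back to back with `cu_seqlens` avoids padding, and the boundaries in `cu_seqlens` are what keep one sequence from attending to another.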

#### `compressed_attention`
Computes attention between a query and compressed key/value tensors, identifying the top-k blocks for sparse attention.

**Parameters:**
- `q`: Query tensor (`total_len, num_heads, head_dim`)
- `k`, `v`: Compressed key and value tensors (`total_compress_len, num_heads, head_dim`)
- `kernel_size`, `kernel_stride`: Compression parameters
- `block_size`: Size of blocks for sparse attention
- `topk`: Number of top blocks to select
- `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for query and compressed key/value
- `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths for query and compressed key/value
- `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
- `init_blocks`: Number of initial blocks forced to be selected (default: `1`)
- `local_blocks`: Number of local blocks forced to be selected (default: `2`)

**Returns:**
- Tuple containing:
  - Attention output tensor
  - Top-k block indices
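This README does not spell out how compressed-attention scores become block indices. The sketch below shows one plausible scheme, in the spirit of the NSA paper, purely as an illustration. Its assumptions: a compressed token's score is credited to every `block_size` block its window overlaps, only blocks up to the query's own block are eligible, the first `init_blocks` and the last `local_blocks` eligible blocks are always kept, and the result is sorted and right-padded with `-1`. The real kernel's scoring may differ; for example, it may aggregate scores across the query heads that share a kv head. `select_topk_blocks` is a hypothetical helper.

```python
from typing import List


def select_topk_blocks(
    cmp_scores: List[float],  # one query's attention probs over its visible compressed tokens
    query_pos: int,           # position of the query within its sequence
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> List[int]:
    num_blocks = query_pos // block_size + 1  # causal: blocks up to the query's own block
    block_scores = [0.0] * num_blocks
    for i, p in enumerate(cmp_scores):
        lo = i * kernel_stride  # compressed token i summarizes [lo, lo + kernel_size)
        hi = lo + kernel_size - 1
        # assumption: credit the score to every selection block the window overlaps
        for b in range(lo // block_size, min(hi // block_size, num_blocks - 1) + 1):
            block_scores[b] += p
    forced = set(range(min(init_blocks, num_blocks)))
    forced |= set(range(max(0, num_blocks - local_blocks), num_blocks))
    rest = sorted(
        (b for b in range(num_blocks) if b not in forced),
        key=lambda b: -block_scores[b],  # stable sort: ties keep the lower block index first
    )
    chosen = sorted((sorted(forced) + rest)[:topk])  # forced blocks take priority
    return chosen + [-1] * (topk - len(chosen))      # right-pad with -1
```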

#### `topk_sparse_attention`
Performs sparse attention using precomputed top-k block indices. If a query attends to fewer than `topk` key/value blocks, the `topk_idx` should be padded with `-1` on the right.

**Parameters:**
- `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)
- `topk_idx`: Precomputed top-k indices (`num_kv_heads, total_len, topk`)
- `block_size`: Block size for sparse attention (recommended: `64` or `128`)
- `cu_seqlens`: Cumulative sequence lengths
- `softmax_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)

**Returns:**
- Attention output tensor (`total_len, num_q_heads, head_dim`)
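To make the `topk_idx` convention concrete, the sketch below lists the key positions (relative to the start of the sequence) that one query row can see. It skips `-1` padding entries and, as an assumption, applies causal masking inside the query's own block so that no key after the query is included. `attended_positions` is a hypothetical helper, not part of this package.

```python
from typing import List


def attended_positions(topk_row: List[int], query_pos: int, block_size: int) -> List[int]:
    positions: List[int] = []
    for b in topk_row:
        if b < 0:  # -1 padding: fewer than topk blocks were selected
            continue
        start = b * block_size
        end = min(start + block_size, query_pos + 1)  # causal cut at the query position
        positions.extend(range(start, end))  # empty if the block starts after the query
    return sorted(positions)
```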

---

## Usage Example

Below is a typical workflow demonstrating how to combine these sparse attention functions:

```python
import torch
from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention

# Example input setup
num_q_heads = 64
num_kv_heads = 4
head_dim = 128
cu_seqlens = torch.tensor([0, 1024, 8192, 16384], dtype=torch.int32).cuda()

# Query, key, and value tensors
query = torch.randn(16384, num_q_heads, head_dim, dtype=torch.bfloat16).cuda()
key = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()
value = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()

# Compression weights and positional embeddings
kernel_size = 32
kernel_stride = 16
wk = torch.randn(num_kv_heads, kernel_size * head_dim, head_dim, dtype=torch.bfloat16).cuda()
wv = torch.randn_like(wk)
pe = torch.randn(num_kv_heads, kernel_size, head_dim, dtype=torch.bfloat16).cuda()

# Parameters for top-k sparse attention
block_size = 64
topk = 16

# 1. Compress key and value tensors
compressed_key, compressed_cu_seqlens = linear_compress(
    key, wk, cu_seqlens, kernel_size, kernel_stride, pe
)
compressed_value, _ = linear_compress(
    value, wv, cu_seqlens, kernel_size, kernel_stride, None
)

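# 2. Compute attention with compressed key/value and get top-k indices
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
compressed_attn_output, topk_idx = compressed_attention(
    query,
    compressed_key,
    compressed_value,
    kernel_size,
    kernel_stride,
    block_size,
    topk,
    cu_seqlens,
    compressed_cu_seqlens,
    seqlens.max().item(),  # max_seqlen_q
    compressed_seqlens.max().item(),  # max_seqlen_k
    init_blocks=1,
    local_blocks=2,
)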

# 3. Perform top-k sparse attention
sparse_attn_output = topk_sparse_attention(
    query,
    key,
    value,
    topk_idx,
    block_size,
    cu_seqlens,
)

# 4. Combine attention outputs (e.g., average)
attn_output = (compressed_attn_output + sparse_attn_output) / 2
```

For a complete implementation of the Native Sparse Attention module, see `native_sparse_attention/module/native_sparse_attention.py`.


================================================
FILE: native_sparse_attention/ops/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# compress method
from native_sparse_attention.ops.triton.weighted_pool import (
    weightedpool_compress,
    avgpool_compress,
)
from native_sparse_attention.ops.triton.linear_compress import linear_compress

# prefill attention
from native_sparse_attention.ops.triton.flash_attention import flash_attention_varlen
from native_sparse_attention.ops.triton.compressed_attention import compressed_attention
from native_sparse_attention.ops.triton.topk_sparse_attention import (
    topk_sparse_attention,
)

# decode attention
from native_sparse_attention.ops.triton.flash_attention_decode import (
    flash_attention_decode,
)
from native_sparse_attention.ops.torch.compressed_attention_decode import (
    compressed_attention_decode,
)
from native_sparse_attention.ops.triton.topk_sparse_attention_decode import (
    topk_sparse_attention_decode,
)

__all__ = [
    # compress method
    "avgpool_compress",
    "weightedpool_compress",
    "linear_compress",
    # prefill attention, trainable
    "flash_attention_varlen",
    "compressed_attention",
    "topk_sparse_attention",
    # decode attention, no grad
    "flash_attention_decode",
    "compressed_attention_decode",
    "topk_sparse_attention_decode",
]


================================================
FILE: native_sparse_attention/ops/torch/__init__.py
================================================


================================================
FILE: native_sparse_attention/ops/torch/compress_key_value.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional
from einops import rearrange, einsum


def avgpool_compress_torch(
    x: torch.Tensor,
    w: torch.Tensor,
    cu_seqlens,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key and value tensor with kernel_size and kernel_stride.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
        w (torch.Tensor): no weight for avgpool, must be None.
        cu_seqlens (torch.Tensor): int32 tensor of shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
        kernel_size (int): size of each compression window; every (kernel_size, head_dim) block is compressed to (1, head_dim)
        kernel_stride (int): stride for each compress kernel
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
    """
    # dtype check
    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
    assert cu_seqlens.dtype == torch.int32
    assert x.dtype == pe.dtype if pe is not None else True

    # shape check
    total_len, num_heads, head_dim = x.shape
    batch_size = cu_seqlens.shape[0] - 1
    assert w is None, "don't need additional weight for avgpool"
    assert kernel_size % kernel_stride == 0
    assert kernel_size in {16, 32, 64, 128}

    # compute seqlens after compression
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
    # corner case, if sequence_length < kernel_size, no compression for this sequence
    y_seqlens[seqlens < kernel_size] = 0
    y_cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
            torch.cumsum(y_seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)

    # pad and rearrange x
    x = rearrange(x, "n h d -> n (h d)")
    splited_x = torch.split(x, seqlens.tolist(), 0)
    x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)
    x = rearrange(x, "b n d -> b d n")
    # avgpool
    y = torch.nn.functional.avg_pool1d(x, kernel_size=kernel_size, stride=kernel_stride)
    y = rearrange(y, "b (h d) n -> b n h d", h=num_heads)
    # only keep useful part
    y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)

    # position embedding as a bias
    if pe is not None:
        bias = torch.mean(pe, dim=1)
        y = y + bias.unsqueeze(0)

    return y, y_cu_seqlens


def weightedpool_compress_torch(
    x: torch.Tensor,
    w: torch.Tensor,  # [num_heads, kernel_size]
    cu_seqlens,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key and value tensor with kernel_size and kernel_stride.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size)
        cu_seqlens (torch.Tensor): int32 tensor of shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
        kernel_size (int): size of each compression window; every (kernel_size, head_dim) block is compressed to (1, head_dim)
        kernel_stride (int): stride for each compress kernel
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
    """
    # dtype check
    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
    assert x.dtype == w.dtype
    assert x.dtype == pe.dtype if pe is not None else True
    assert cu_seqlens.dtype == torch.int32
    # shape check
    total_len, num_heads, head_dim = x.shape
    batch_size = cu_seqlens.shape[0] - 1
    assert w.shape[0] == num_heads
    assert w.shape[1] == kernel_size
    assert kernel_size % kernel_stride == 0
    assert kernel_size in {16, 32, 64, 128}
    # compute seqlens after compression
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
    # corner case, if sequence_length < kernel_size, no compression for this sequence
    y_seqlens[seqlens < kernel_size] = 0
    y_cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
            torch.cumsum(y_seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    # pad and rearrange x
    x = rearrange(x, "n h d -> n (h d)")
    splited_x = torch.split(x, seqlens.tolist(), 0)
    x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)
    x = rearrange(x, "b n (h d) -> b h n d", h=num_heads)
    x = x.as_strided(
        size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),
        stride=(
            x.stride(0),
            x.stride(1),
            kernel_stride * x.stride(2),
            x.stride(2),
            x.stride(3),
        ),
    )
    y = einsum(x, w, "b h n k d, h k -> b n h d")
    # only keep useful part
    y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)

    # position embedding as a bias
    if pe is not None:
        bias = einsum(pe, w, "h k d, h k -> h d")
        y = y + bias.unsqueeze(0)

    return y, y_cu_seqlens


def linear_compress_torch(
    x: torch.Tensor,
    w: torch.Tensor,  # [num_heads, kernel_size * head_dim, head_dim]
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim)
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        kernel_size (int): kernel_size, each (kernel_size, head_dim) block is compressed to (1, head_dim)
        kernel_stride (int): stride for each compress kernel
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
    """
    # dtype check
    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
    assert x.dtype == w.dtype
    assert pe is None or pe.dtype == x.dtype
    assert cu_seqlens.dtype == torch.int32
    # shape check
    total_len, num_heads, head_dim = x.shape
    batch_size = cu_seqlens.shape[0] - 1
    assert w.shape[0] == num_heads
    assert w.shape[1] == kernel_size * head_dim
    assert w.shape[2] == head_dim
    assert kernel_size % kernel_stride == 0
    assert kernel_size in {16, 32, 64, 128}
    # compute seqlens after compression
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
    # corner case: a sequence shorter than kernel_size yields no compressed tokens
    y_seqlens[seqlens < kernel_size] = 0
    y_cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
            torch.cumsum(y_seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    # pad and rearrange x
    x = rearrange(x, "n h d -> n (h d)")
    split_x = torch.split(x, seqlens.tolist(), 0)
    x = torch.nn.utils.rnn.pad_sequence(split_x, batch_first=True)
    x = rearrange(x, "b n (h d) -> b h n d", h=num_heads)
    x = x.as_strided(
        size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),
        stride=(
            x.stride(0),
            x.stride(1),
            kernel_stride * x.stride(2),
            x.stride(2),
            x.stride(3),
        ),
    )
    y = einsum(
        x,
        rearrange(w, "h (k d) D -> h k d D", k=kernel_size),
        "b h n k d, h k d D -> b n h D",
    )
    # drop the padded tail of each sequence, keeping only its valid compressed tokens
    y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)

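    # pe is added to every window before compression; since compression is linear,
    # this equals adding one constant bias per head afterwards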
    if pe is not None:
        pe = rearrange(pe, "h k d -> h (k d)")
        bias = einsum(pe, w, "h D, h D d -> h d")
        y = y + bias.unsqueeze(0)

    return y, y_cu_seqlens
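

# Illustrative sketch (hypothetical helper, not part of the library API): why
# ``pe`` can be applied as a bias after compression. For a single head, one
# window is flattened to a vector of length kernel_size * head_dim and
# multiplied by ``w`` of shape (kernel_size * head_dim, head_dim). The map is
# linear, so compressing (window + pe) equals compressing the window and then
# adding the compressed pe.
def _example_compress_window(window, w):
    # window: flat list of length K; w: K rows, each a list of length D_out
    out_dim = len(w[0])
    return [
        sum(window[i] * w[i][o] for i in range(len(window))) for o in range(out_dim)
    ]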


================================================
FILE: native_sparse_attention/ops/torch/compressed_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Tuple
from collections import Counter
from einops import rearrange


def transform_score(
    score: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
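) -> torch.Tensor:
    """Convert scores over compressed keys into scores over key-value blocks.

    Each block of block_size tokens accumulates the score of every compressed key
    overlapping it, weighted by the overlap measured in units of kernel_stride.
    The first init_blocks blocks and the local_blocks blocks ending at each
    query's own block get an infinite score, so topk always ranks them first.

    Returns:
        torch.Tensor: shape [num_k_heads, total_query_len, ceil(max_seqlen_q / block_size)]
    """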
    num_k_heads, total_query_len, _ = score.shape
    pad_len = kernel_size // kernel_stride - 1
    score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)
    max_blocks = math.ceil(max_seqlen_q / block_size)
    full_blocks = max_seqlen_q // block_size
    block_score = torch.zeros(
        num_k_heads,
        total_query_len,
        max_blocks,
        dtype=torch.float32,
        device=score.device,
    )
    offs = (
        torch.arange(kernel_size // kernel_stride)[:, None]
        + torch.arange(block_size // kernel_stride)[None, :]
    ).view(-1)
    offs = dict(Counter(offs.tolist()))
    for k, v in offs.items():
        block_score[..., :full_blocks] += (
            v * score[..., k :: block_size // kernel_stride][..., :full_blocks]
        )
    # set init block and local block score
    batch_size = cu_seqlens_q.shape[0] - 1
    q_idx = torch.cat(
        [
            torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=score.device)
            for i in range(batch_size)
        ],
        dim=0,
    )
    q_idx = q_idx // block_size
    b_idx = torch.arange(max_blocks, device=score.device)
    block_score[..., :init_blocks] = torch.inf
    local_mask = (q_idx[:, None] >= b_idx[None, :]) & (
        q_idx[:, None] < b_idx[None, :] + local_blocks
    )
    local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)
    block_score[local_mask] = torch.inf
    return block_score
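

# Illustrative sketch (hypothetical helper, not part of the library API): the
# ``offs`` weights used in ``transform_score``. After padding the score by
# kernel_size // kernel_stride - 1 on each side, offset ``o`` inside a block's
# window of the padded score receives weight = number of (a, c) pairs with
# a + c == o, which equals how many kernel_stride-sized chunks that compressed
# key shares with the block.
def _example_block_overlap_weights(kernel_size, kernel_stride, block_size):
    from collections import Counter

    key_chunks = kernel_size // kernel_stride  # chunks covered by one compressed key
    block_chunks = block_size // kernel_stride  # chunks covered by one block
    return dict(
        Counter(a + c for a in range(key_chunks) for c in range(block_chunks))
    )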


def compressed_attention_torch(
    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    v: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float = None,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Attention between query and compressed key and value. Implemented with torch, only for debugging.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of blocks for each query.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
        max_seqlen_q (int): max q len of the batch.
        max_seqlen_k (int): max k len of the batch.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
    """
    assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
    total_query_len, num_q_heads, head_dim = q.shape
    total_key_len, num_k_heads, _ = k.shape
    num_share_q_heads = num_q_heads // num_k_heads
    batch_size = cu_seqlens_q.shape[0] - 1
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)
    # get mask
    mask = torch.zeros(
        (total_query_len, total_key_len), dtype=torch.bool, device=q.device
    )
    for b in range(batch_size):
        q_len, k_len = (
            cu_seqlens_q[b + 1] - cu_seqlens_q[b],
            cu_seqlens_k[b + 1] - cu_seqlens_k[b],
        )
        k_max_ids = (
            torch.arange(k_len, device=q.device) * kernel_stride + kernel_size - 1
        )
        q_ids = torch.arange(q_len, device=q.device)
        mask[
            cu_seqlens_q[b] : cu_seqlens_q[b + 1], cu_seqlens_k[b] : cu_seqlens_k[b + 1]
        ] = (q_ids[:, None] >= k_max_ids[None, :])
    # attention
    qk = (
        torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1))
        * sm_scale
    )
    qk = qk.masked_fill_(~mask[None, ...], -torch.inf)
    # a query near the start of a sequence may see no compressed key at all; its
    # softmax row is then all NaN and is zeroed out below
    qk = qk.softmax(dim=-1, dtype=torch.float32)
    qk = qk.nan_to_num(0)
    attn_output = torch.einsum(
        "hqk,khd->qhd", qk.to(v.dtype), v.repeat_interleave(num_share_q_heads, 1)
    )
    with torch.no_grad():
        # sum the probabilities over the query heads that share each kv head (GQA)
        # qk shape [num_q_heads, total_q_len, total_k_len] -> [num_k_heads, total_q_len, total_k_len]
        score = torch.zeros(
            num_k_heads,
            cu_seqlens_q[-1],
            max_seqlen_k,
            dtype=torch.float32,
            device=q.device,
        )
        qk = rearrange(qk, "(h g) q k -> h g q k", h=num_k_heads).sum(1)
        for b in range(batch_size):
            score[
                :,
                cu_seqlens_q[b] : cu_seqlens_q[b + 1],
                : cu_seqlens_k[b + 1] - cu_seqlens_k[b],
            ] = qk[
                :,
                cu_seqlens_q[b] : cu_seqlens_q[b + 1],
                cu_seqlens_k[b] : cu_seqlens_k[b + 1],
            ]
        # transform score to block-wise score
        score = transform_score(
            score,
            kernel_size,
            kernel_stride,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            init_blocks,
            local_blocks,
        )
        # get topk
        batch_size = cu_seqlens_q.shape[0] - 1
        q_idx = torch.cat(
            [
                torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
                for i in range(batch_size)
            ],
            dim=0,
        )
        q_idx = q_idx // block_size
        topk = min(topk, score.shape[-1])
        topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
        topk_idx[topk_idx > q_idx[None, :, None]] = -1
        topk_idx = topk_idx.to(torch.int32)
    return attn_output, topk_idx
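

# Illustrative sketch (hypothetical helper, not part of the library API): the
# causal rule used to build ``mask`` in ``compressed_attention_torch``. A query
# at position q may attend compressed key j only once the whole window of j has
# been seen, i.e. q >= j * kernel_stride + kernel_size - 1.
def _example_compressed_causal_mask(q_len, k_len, kernel_size, kernel_stride):
    return [
        [q >= j * kernel_stride + kernel_size - 1 for j in range(k_len)]
        for q in range(q_len)
    ]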


================================================
FILE: native_sparse_attention/ops/torch/compressed_attention_decode.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Tuple, Optional
from collections import Counter
from einops import rearrange


def transform_score(
    score: torch.Tensor,
    seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
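) -> torch.Tensor:
    """Convert scores over compressed keys into scores over key-value blocks for
    the last query of each sequence.

    score has shape [num_k_heads, batch_size, kv_len]; the result has shape
    [num_k_heads, batch_size, ceil(max(seqlens) / block_size)], with init blocks
    and each query's local blocks set to +inf.
    """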
    num_k_heads, batch_size, kv_len = score.shape
    pad_len = kernel_size // kernel_stride - 1
    score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)
    max_seqlen = seqlens.max().item()
    max_blocks = math.ceil(max_seqlen / block_size)
    full_blocks = max_seqlen // block_size
    block_score = torch.zeros(
        num_k_heads,
        batch_size,
        max_blocks,
        dtype=torch.float32,
        device=score.device,
    )
    offs = (
        torch.arange(kernel_size // kernel_stride)[:, None]
        + torch.arange(block_size // kernel_stride)[None, :]
    ).view(-1)
    offs = dict(Counter(offs.tolist()))
    for k, v in offs.items():
        block_score[..., :full_blocks] += (
            v * score[..., k :: block_size // kernel_stride][..., :full_blocks]
        )
    # set init block and local block score
    q_idx = (seqlens - 1) // block_size
    b_idx = torch.arange(max_blocks, device=score.device)
    block_score[..., :init_blocks] = torch.inf
    local_mask = (q_idx[:, None] >= b_idx[None, :]) & (
        q_idx[:, None] < b_idx[None, :] + local_blocks
    )
    local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)
    block_score[local_mask] = torch.inf
    block_score = block_score.nan_to_num(0, torch.inf, -torch.inf)
    return block_score


def compressed_attention_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seqlens: torch.Tensor,
    compress_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
    sm_scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
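    """Attention between one decoding query per sequence and the compressed key/value cache. Implemented with torch, only for debugging.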

    Args:
        q (torch.Tensor): shape [batch_size, num_q_heads, head_dim]
        k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
        seqlens (torch.Tensor): original kv length for each sequence
        compress_seqlens (torch.Tensor): kv length for each sequence after compression
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of blocks for each query.
        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention_decode
    """
    assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
    batch_size, num_q_heads, head_dim = q.shape
    batch_size, kv_len, num_k_heads, _ = k.shape
    num_share_q_heads = num_q_heads // num_k_heads
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)
    # no compressed key exists yet (every sequence is shorter than kernel_size)
    if kv_len == 0:
        return torch.zeros_like(q), torch.zeros(
            num_k_heads, batch_size, 1, device=q.device, dtype=torch.int32
        )
    # get mask
    mask = (
        compress_seqlens[:, None]
        > torch.arange(
            kv_len, device=compress_seqlens.device, dtype=compress_seqlens.dtype
        )[None, :]
    )
    # attention
    qk = (
        torch.einsum(
            "bihgd, bjhgd -> bhgij",
            rearrange(q, "b (h g) d -> b 1 h g d", g=num_share_q_heads),
            rearrange(k, "b j h d -> b j h 1 d"),
        )
        * sm_scale
    )
    qk = qk.masked_fill_(~mask[:, None, None, None, :], -torch.inf)
    qk = qk.softmax(dim=-1, dtype=torch.float32)
    qk = qk.nan_to_num_(0)  # qk is nan when seqlen == 0
    attn_output = torch.einsum(
        "bhgij, bjhgd -> bihgd",
        qk.to(v.dtype),
        rearrange(v, "b k h d -> b k h 1 d"),
    )
    attn_output = rearrange(attn_output, "b 1 h g d -> b (h g) d")

    # get score
    score = rearrange(qk.sum(2).squeeze(2), "b h j -> h b j")
    # transform score to block-wise score
    score = transform_score(
        score,
        seqlens,
        kernel_size,
        kernel_stride,
        block_size,
        init_blocks,
        local_blocks,
    )
    # get topk
    q_idx = (seqlens - 1) // block_size
    topk = min(topk, score.shape[-1])
    topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
    topk_idx[topk_idx > q_idx[None, :, None]] = -1
    topk_idx = topk_idx.to(torch.int32)
    return attn_output, topk_idx
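

# Illustrative sketch (hypothetical helper, not part of the library API): which
# blocks ``transform_score`` sets to +inf for one decoding query at position
# seqlen - 1. The first ``init_blocks`` blocks and the ``local_blocks`` blocks
# ending at the query's own block receive an infinite score, so topk ranks
# them ahead of every other block.
def _example_forced_blocks(seqlen, block_size, init_blocks=1, local_blocks=2):
    max_blocks = -(-seqlen // block_size)  # ceil(seqlen / block_size)
    q_block = (seqlen - 1) // block_size  # block holding the query token
    forced = set(range(min(init_blocks, max_blocks)))
    forced.update(b for b in range(max_blocks) if b <= q_block < b + local_blocks)
    return sorted(forced)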


================================================
FILE: native_sparse_attention/ops/torch/topk_sparse_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Optional


def topk_sparse_attention_torch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size_k: int,
    cu_seqlens: torch.Tensor,
    softmax_scale: Optional[float] = None,
    block_size_q: int = 1,
) -> torch.Tensor:
    """Simple topk sparse attention varlen version implemented in torch. Extremely slow, only for debugging.

    Args:
        q (torch.Tensor): shape [total_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        topk_idx (torch.Tensor): topk key value block indices for each query block, shape [num_kv_heads, total_q_blocks, topk], where total_q_blocks equals total_len when block_size_q is 1. -1 means padding.
        block_size_k (int): key value block size.
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.
        softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).
        block_size_q (int, optional): query block size. Defaults to 1.

    Returns:
        torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
    """
    total_seqlen, num_q_heads, head_dim = q.shape
    total_seqlen, num_kv_heads, head_dim = k.shape
    num_share_q_heads = num_q_heads // num_kv_heads
    batch_size = cu_seqlens.shape[0] - 1
    topk = topk_idx.shape[-1]
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    seqblocks_q = torch.ceil(seqlens / block_size_q).to(torch.int32)
    cu_seqblocks_q = torch.nn.functional.pad(seqblocks_q.cumsum(0), (1, 0), value=0)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    # get mask
    mask = torch.zeros(
        (num_kv_heads, total_seqlen, total_seqlen), dtype=torch.bool, device=q.device
    )
    for i in range(batch_size):
        num_q_blocks = math.ceil(seqlens[i] / block_size_q)
        num_kv_blocks = math.ceil(seqlens[i] / block_size_k)
        for h in range(num_kv_heads):
            temp_mask = torch.zeros(
                num_q_blocks, num_kv_blocks, dtype=torch.bool, device=q.device
            )
            temp_idx = topk_idx[h, cu_seqblocks_q[i] : cu_seqblocks_q[i + 1]].clone()
            temp_idx[temp_idx < 0] = 0
            temp_mask[torch.arange(num_q_blocks).to(q.device)[:, None], temp_idx] = True
            temp_mask = torch.repeat_interleave(temp_mask, block_size_q, dim=0)
            temp_mask = torch.repeat_interleave(temp_mask, block_size_k, dim=1)
            temp_mask = temp_mask[: seqlens[i], : seqlens[i]]
            mask[
                h, cu_seqlens[i] : cu_seqlens[i + 1], cu_seqlens[i] : cu_seqlens[i + 1]
            ] = temp_mask
    mask = torch.tril(mask).repeat_interleave(num_share_q_heads, 0)
    # qk attn
    qk = (
        torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1))
        * softmax_scale
    )
    qk = torch.masked_fill(qk, ~mask, -torch.inf)
    qk = torch.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype)
    o = torch.einsum("hqk,khd->qhd", qk, v.repeat_interleave(num_share_q_heads, 1))
    return o
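

# Illustrative sketch (hypothetical helper, not part of the library API): how
# ``topk_sparse_attention_torch`` expands per-query-block key-block indices into
# a causal token-level mask. Here -1 entries are simply skipped; the torch
# reference above instead clamps them to block 0, which gives the same mask
# whenever block 0 is already selected (e.g. with init_blocks >= 1).
def _example_token_mask(topk_idx, seqlen, block_size_q, block_size_k):
    mask = [[False] * seqlen for _ in range(seqlen)]
    for q in range(seqlen):
        for blk in topk_idx[q // block_size_q]:
            if blk < 0:
                continue  # padding entry
            start = blk * block_size_k
            for k in range(start, min(start + block_size_k, seqlen)):
                mask[q][k] = k <= q  # causal: never attend to future tokens
    return mask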


================================================
FILE: native_sparse_attention/ops/triton/__init__.py
================================================


================================================
FILE: native_sparse_attention/ops/triton/compressed_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Tuple, Union
from collections import Counter
import torch
import triton
import triton.language as tl
import warnings
from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu


IS_HOPPER_GPU = is_hopper_gpu()


@triton.jit
def forward_kernel(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x h x d
    v_ptr,  # V: n x h x d
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # size and stride at compression
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    # 1.44269504 = log2(e): scores are scaled into base 2 so exp2/log2 can be used
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 queries, because they do not attend to any compressed key
    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
    if q_start_in_seq >= q_len:
        return
    # init qkv pointer
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # load q
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics
    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
    # attention
    lo = 0
    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
    for i in range(lo, hi, BLOCK_SIZE_K):
        i = tl.multiple_of(i, BLOCK_SIZE_K)
        # load k
        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.where(
            off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
        )
        qk += tl.dot(q, k) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # scale acc_o
        acc_o_scale = tl.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # final scale
    acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + q_start * stride_on + pid_h * stride_oh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_on, stride_od),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
    # save lse
    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
    tl.store(l_ptrs, lse_i, mask=off_q < q_len)
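

# Illustrative sketch (hypothetical helper, not part of the library API): the
# base-2 online log-sum-exp that ``forward_kernel`` maintains per query row.
# Scores are assumed to be pre-multiplied by log2(e) (the 1.44269504 factor), so
# exp2/log2 replace exp/log; each chunk of BLOCK_SIZE_K scores updates the
# running max ``m`` and ``lse`` the same way the kernel loop does.
def _example_online_lse2(scores, chunk_size):
    import math

    m = float("-inf")
    lse = float("-inf")
    for start in range(0, len(scores), chunk_size):
        chunk = scores[start : start + chunk_size]
        m_new = max(m, max(chunk))
        l_chunk = sum(2.0 ** (s - m_new) for s in chunk)
        # rescale the previous sum to the new max, then add this chunk
        lse = m_new + math.log2(2.0 ** (lse - m_new) + l_chunk)
        m = m_new
    return lse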


@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid_n = tl.program_id(0)
    pid_h = tl.program_id(1)
    off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    off_d = tl.arange(0, BLOCK_SIZE_D)
    o = tl.load(
        o_ptr
        + off_n[:, None] * stride_on
        + pid_h * stride_oh
        + off_d[None, :] * stride_od,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    do = tl.load(
        do_ptr
        + off_n[:, None] * stride_don
        + pid_h * stride_doh
        + off_d[None, :] * stride_dod,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    tl.store(
        delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
    )
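

# Illustrative sketch (hypothetical helper, not part of the library API): why
# ``backward_sum_o_do`` only needs rowsum(O * dO). For one query row with
# softmax probabilities p over keys and output o = sum_j p_j v_j, the softmax
# backward term sum_j p_j * dp_j (with dp_j = dO . v_j) equals dO . o, so it can
# be computed once from O and dO instead of from the full attention matrix.
def _example_delta_two_ways(p, values, d_out):
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    head_dim = len(d_out)
    # o = sum_j p_j * v_j, the attention output for this row
    o = [sum(p[j] * values[j][d] for j in range(len(p))) for d in range(head_dim)]
    from_output = dot(o, d_out)  # what backward_sum_o_do computes
    from_probs = sum(p[j] * dot(d_out, values[j]) for j in range(len(p)))
    return from_output, from_probs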


@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    off_k = (
        pid_k * BLOCK_SIZE_K * kernel_stride
        + tl.arange(0, BLOCK_SIZE_K) * kernel_stride
        + kernel_size
        - 1
    )
    # load k v and keep in SRAM
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(HEAD_DIM, q_len),
        strides=(stride_qd, stride_qn),
        offsets=(0, q_lo),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(HEAD_DIM, q_len),
        strides=(stride_dod, stride_don),
        offsets=(0, q_lo),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(1, q_len),
        strides=(0, stride_dn),
        offsets=(0, q_lo),
        block_shape=(1, BLOCK_SIZE_Q),
        order=(1, 0),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(1, q_len),
        strides=(0, stride_ln),
        offsets=(0, q_lo),
        block_shape=(1, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    # loop for q blocks
    for i in range(q_lo, q_len, BLOCK_SIZE_Q):
        # load
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q]
        qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
        qk += tl.dot(k, q) * qk_scale
        # compute p, ds
        # [BLOCK_SIZE_K, BLOCK_SIE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q]
        p = tl.exp2(qk - lse)
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        dp = tl.dot(v, do)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        p = p.to(do.dtype)
        ds = ds.to(q.dtype)
        # update dk and dv
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
        dk += tl.dot(ds, tl.trans(q))
        dv += tl.dot(p, tl.trans(do))
        # increment pointers
        q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
        do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
        lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
        d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
    # save dk dv
    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
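

# Reference sketch (hypothetical helpers, not part of the original module). Compressed key j
# summarizes raw tokens [j * kernel_stride, j * kernel_stride + kernel_size), so under the causal
# rule used above a query at position t may attend to it only when
# t >= j * kernel_stride + kernel_size - 1. For the key block starting at compressed index
# pid_k * BLOCK_SIZE_K, the earliest such query is the `q_lo` that backward_dkdv starts its loop
# from. The brute-force version is only an illustrative cross-check, assuming a single sequence.
def _ref_first_query_for_key_block(pid_k, block_size_k, kernel_size, kernel_stride):
    # the first compressed key of the block is the one that becomes visible earliest
    first_key = pid_k * block_size_k
    return first_key * kernel_stride + kernel_size - 1


def _ref_first_query_brute_force(pid_k, block_size_k, kernel_size, kernel_stride, q_len):
    # scan queries in order and return the first one that sees any key of the block
    keys = range(pid_k * block_size_k, (pid_k + 1) * block_size_k)
    for t in range(q_len):
        if any(t >= j * kernel_stride + kernel_size - 1 for j in keys):
            return t
    return q_len  # no query in this sequence sees the block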


@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504  # sm_scale * log2(e), so the softmax can use exp2
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 queries, because they do not attend to any compressed key
    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
    if q_start_in_seq >= q_len:
        return
    # init pointers
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dq_ptrs = tl.make_block_ptr(
        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_dqn, stride_dqd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_vd, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_don, stride_dod),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    # load q, do, lse, delta, and keep in SRAM
    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dq
    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
    lo = 0
    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
    for i in range(lo, hi, BLOCK_SIZE_K):
        # load
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.where(
            off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
        )
        qk += tl.dot(q, tl.trans(k)) * qk_scale
        # compute p, ds
        p = tl.exp2(qk - lse)
        dp = tl.dot(do, v)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        ds = ds.to(q.dtype)
        # update dq
        dq += tl.dot(ds, k)
        # increment pointers
        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
    # save dq
    tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
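

# Reference sketch (hypothetical helper, not part of the original module). Compressed key j is
# visible to query t when j * kernel_stride + kernel_size - 1 <= t, so t sees
# (t - kernel_size + 1) // kernel_stride + 1 keys, clamped to [0, k_len]. Evaluated at the last
# query of a block, t = q_start_in_seq + BLOCK_SIZE_Q - 1, this is the loop bound `hi` used in
# backward_dq above.
def _ref_num_visible_compressed_keys(t, kernel_size, kernel_stride, k_len):
    if t < kernel_size - 1:
        # the first kernel_size - 1 queries see no compressed key at all
        return 0
    return min(k_len, (t - kernel_size + 1) // kernel_stride + 1)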


def _compressed_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    # dtype check
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    # shape
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    v_len, num_v_heads, head_dim = v.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert k_len == v_len and q_len > k_len
    # gqa
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # output tensor
    o = torch.zeros_like(q)
    lse = torch.full(
        (num_q_heads, q_len),
        fill_value=-torch.inf,
        dtype=torch.float32,
        device=q.device,
    )
    # launch kernel
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
    )
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_K = 128
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
    forward_kernel[grid](
        q,
        k,
        v,
        o,
        lse,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        lse.stride(0),
        lse.stride(1),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse
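

# Reference sketch (hypothetical helper, pure Python, not part of the original module) of what
# the forward pass computes for one query row of one head, assuming a single sequence and no GQA.
# Keys are masked with the compressed-key causal rule, the softmax is evaluated in base 2
# (qk_scale = sm_scale * log2(e)), and the returned lse is log2(sum_j exp(sm_scale * q . k_j)),
# the base-2 convention noted in _get_attention_score. A row with no visible key returns a zero
# output and lse = -inf, matching how `o` and `lse` are initialized in _compressed_attention_fwd.
import math


def _ref_compressed_attention_row(q, keys, values, t, kernel_size, kernel_stride, sm_scale):
    qk_scale = sm_scale * 1.44269504  # sm_scale * log2(e)
    v_dim = len(values[0]) if values else len(q)
    visible = [j for j in range(len(keys)) if t >= j * kernel_stride + kernel_size - 1]
    if not visible:
        return [0.0] * v_dim, float("-inf")
    # base-2 logits, shifted by their max for numerical stability
    logits = [qk_scale * sum(a * b for a, b in zip(q, keys[j])) for j in visible]
    m = max(logits)
    lse = m + math.log2(sum(2.0 ** (x - m) for x in logits))
    probs = [2.0 ** (x - lse) for x in logits]
    out = [sum(p * values[j][d] for p, j in zip(probs, visible)) for d in range(v_dim)]
    return out, lse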


def _compressed_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    v_len, num_v_heads, head_dim = v.shape
    o_len, num_o_heads, head_dim = o.shape
    num_share_q_heads = num_q_heads // num_k_heads
    # compute delta = rowsum(o * do) for every query row and head
    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
    grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
    BLOCK_SIZE_O = 256
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
    backward_sum_o_do[grid](
        o,
        do,
        delta,
        o_len,
        head_dim,
        o.stride(0),
        o.stride(1),
        o.stride(2),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        delta.stride(0),
        delta.stride(1),
        BLOCK_SIZE_O=BLOCK_SIZE_O,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
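    # compute dk dv: each query head in a GQA group writes its own slice of the leading
    # num_share_q_heads dimension, and the slices are summed afterwards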
    dk = torch.zeros(
        num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
    )
    dv = torch.zeros(
        num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
    )
    batch_size = cu_seqlens_q.shape[0] - 1
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
    )
    BLOCK_SIZE_Q = 64
    BLOCK_SIZE_K = 128
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
    backward_dkdv[grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dk,
        dv,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dk.stride(0),
        dk.stride(1),
        dk.stride(2),
        dk.stride(3),
        dv.stride(0),
        dv.stride(1),
        dv.stride(2),
        dv.stride(3),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    dk = dk.sum(0)
    dv = dv.sum(0)
    # compute dq
    dq = torch.zeros_like(q)
    grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
    )
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_K = 64
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
    backward_dq[grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dq,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dq.stride(0),
        dq.stride(1),
        dq.stride(2),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return dq, dk, dv
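

# Reference sketch (hypothetical helper, pure Python, not part of the original module) of the
# identity behind backward_sum_o_do. For one query row with softmax probabilities p_j,
# dp_j = do . v_j and o = sum_j p_j * v_j, the softmax backward needs sum_j p_j * dp_j, which
# equals the row-wise dot product delta = o . do. Precomputing delta once per row lets
# backward_dkdv and backward_dq form ds_j = sm_scale * p_j * (dp_j - delta) without a second
# reduction over keys.
def _ref_softmax_backward_row(probs, values, do, sm_scale):
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    o = [sum(p * v[d] for p, v in zip(probs, values)) for d in range(len(do))]
    delta = dot(o, do)  # the per-row value backward_sum_o_do stores
    dp = [dot(do, v) for v in values]
    ds = [sm_scale * p * (g - delta) for p, g in zip(probs, dp)]
    return delta, dp, ds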


class CompressedAttention(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        sm_scale=None,
    ):
        # dtype check
        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
        assert q.dtype == k.dtype and k.dtype == v.dtype
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # softmax scale
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])
        o, lse = _compressed_attention_fwd(
            q,
            k,
            v,
            kernel_size,
            kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        return o, lse

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_k = ctx.max_seqlen_k
        sm_scale = ctx.sm_scale
        kernel_size = ctx.kernel_size
        kernel_stride = ctx.kernel_stride
        dq, dk, dv = _compressed_attention_bwd(
            o,
            do,
            lse,
            q,
            k,
            v,
            kernel_size,
            kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        return dq, dk, dv, None, None, None, None, None, None, None


@triton.jit
def score_kernel(
    q_ptr,
    k_ptr,
    lse_ptr,
    s_ptr,
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_lh,
    stride_ln,
    stride_sh,
    stride_sq,
    stride_sk,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504  # sm_scale * log2(e), matches the base-2 lse
    # get batch id and head id
    pid_bkh = tl.program_id(0)
    pid_b = pid_bkh // NUM_KV_HEADS
    pid_kh = pid_bkh % NUM_KV_HEADS
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
        return
    # init k pointer and load k
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_k * BLOCK_SIZE_K),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
    # init score
    s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
    # loop over gqa heads
    for h in range(NUM_SHARE_Q_HEADS):
        pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
            shape=(q_len, HEAD_DIM),
            strides=(stride_qn, stride_qd),
            offsets=(pid_q * BLOCK_SIZE_Q, 0),
            block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
            order=(1, 0),
        )
        lse_ptrs = tl.make_block_ptr(
            base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
            shape=(q_len, 1),
            strides=(stride_ln, stride_lh),
            offsets=(pid_q * BLOCK_SIZE_Q, 0),
            block_shape=(BLOCK_SIZE_Q, 1),
            order=(0, 1),
        )
        # load q and lse
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.dot(q, k) * qk_scale
        # compute score
        s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
    # save output
    s_ptrs = tl.make_block_ptr(
        base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
        shape=(q_len, k_len),
        strides=(stride_sq, stride_sk),
        offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
        order=(1, 0),
    )
    tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
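

# Reference sketch (hypothetical helper, pure Python, not part of the original module) of
# score_kernel for one query row of one kv head, assuming a single sequence. The probabilities
# are recomputed from the saved base-2 lse rather than stored by the forward pass, and summed
# over the NUM_SHARE_Q_HEADS query heads that share the kv head:
#   s[t][j] = sum_h 2 ** (sm_scale * log2(e) * (q_h[t] . k_j) - lse_h[t])  if key j is visible to t
#   s[t][j] = 0                                                          otherwise
def _ref_group_score_row(qs, keys, lses, t, kernel_size, kernel_stride, sm_scale):
    # qs[h][t]: query vector of shared head h at position t; lses[h][t]: its base-2 lse
    qk_scale = sm_scale * 1.44269504  # sm_scale * log2(e)
    row = []
    for j, k in enumerate(keys):
        if t < j * kernel_stride + kernel_size - 1:
            row.append(0.0)  # causal mask: key j still covers tokens after t
            continue
        total = 0.0
        for h in range(len(qs)):
            logit = qk_scale * sum(a * b for a, b in zip(qs[h][t], k))
            total += 2.0 ** (logit - lses[h][t])
        row.append(total)
    return row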


def _get_attention_score(
    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    lse: torch.Tensor,  # [num_q_heads, total_query_len]
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
) -> torch.Tensor:
    # dtype check
    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
    assert q.dtype == k.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    assert (
        lse.dtype == torch.float32
    )  # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
    # shape
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert q_len > k_len
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)
    # gqa
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # init score
    score = torch.zeros(
        num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device
    )
    # launch kernel
    grid = lambda META: (
        batch_size * num_k_heads,
        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
        triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
    )
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_K = 128
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    score_kernel[grid](
        q,
        k,
        lse,
        score,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        lse.stride(0),
        lse.stride(1),
        score.stride(0),
        score.stride(1),
        score.stride(2),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=8,
        num_stages=3,
    )
    return score


@triton.jit
def _transform_score_kernel(
    s_ptr,  # score, shape: [num_heads, q_len, k_len]
    bs_ptr,  # block wise score: [num_heads, q_len, num_k_block]
    offs,
    cu_seqlens_q,
    # shape
    num_heads,
    num_offs,
    max_k_len,
    max_blocks,
    pad_len,
    # kernel & block size
    block_size,
    block_stride,  # block_size // kernel_stride
    init_blocks,
    local_blocks,
    # stride
    stride_sh,
    stride_sq,
    stride_sk,
    stride_bsh,
    stride_bsq,
    stride_bsk,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_O: tl.constexpr,
):
    pid_bh = tl.program_id(0)
    pid_b = pid_bh // num_heads
    pid_h = pid_bh % num_heads
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = pid_k * BLOCK_SIZE_K
    if pid_q * BLOCK_SIZE_Q >= q_len:
        return
    # load weight
    off_o = tl.arange(0, BLOCK_SIZE_O)
    w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
    # load score
    off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
    off_k = off_k[None, :] + off_o[:, None]
    s_ptrs = (
        s_ptr
        + q_start * stride_sq
        + pid_h * stride_sh
        + off_q[:, None, None] * stride_sq
        + off_k[None, :, :] * stride_sk
    )
    # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
    s = tl.load(
        s_ptrs,
        mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
        other=0,
    )
    s = s * w[None, :, None]
    s = tl.sum(s, axis=1)
    # init mask and local mask
    off_bq = off_q // block_size
    off_bk = k_start + tl.arange(0, BLOCK_SIZE_K)
    s = tl.where(
        (
            (off_bq[:, None] >= off_bk[None, :])  # causal mask
            & (off_bq[:, None] < off_bk[None, :] + local_blocks)  # local window
        )
        | (off_bk[None, :] < init_blocks),  # init window
        float("inf"),
        s,
    )
    # store block wise score
    bs_ptrs = (
        bs_ptr
        + q_start * stride_bsq
        + pid_h * stride_bsh
        + off_q[:, None] * stride_bsq
        + off_bk[None, :] * stride_bsk
    )
    tl.store(
        bs_ptrs,
        s,
        mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :],
    )


def transform_score(
    score: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> torch.Tensor:
    num_k_heads, total_query_len, max_key_len = score.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    pad_len = kernel_size // kernel_stride - 1
    max_blocks = math.ceil(max_seqlen_q / block_size)
    block_score = torch.zeros(
        num_k_heads,
        total_query_len,
        max_blocks,
        dtype=torch.float32,
        device=score.device,
    )
    offs = (
        torch.arange(kernel_size // kernel_stride, device=score.device)[:, None]
        + torch.arange(block_size // kernel_stride, device=score.device)[None, :]
    ).view(-1)
    offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max())
    num_offs = int(offs.shape[0])
    BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
    BLOCK_SIZE_O = triton.next_power_of_2(num_offs)
    BLOCK_SIZE_Q = 8
    grid = (
        num_k_heads * batch_size,
        triton.cdiv(total_query_len, BLOCK_SIZE_Q),
        triton.cdiv(max_blocks, BLOCK_SIZE_K),
    )
    _transform_score_kernel[grid](
        score,
        block_score,
        offs,
        cu_seqlens_q,
        num_k_heads,
        offs.shape[0],
        max_key_len,
        max_blocks,
        pad_len,
        block_size,
        block_size // kernel_stride,
        init_blocks,
        local_blocks,
        score.stride(0),
        score.stride(1),
        score.stride(2),
        block_score.stride(0),
        block_score.stride(1),
        block_score.stride(2),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_O=BLOCK_SIZE_O,
        num_warps=8,
        num_stages=3,
    )
    return block_score
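

# Reference sketch (hypothetical helpers, pure Python, not part of the original module) of how
# transform_score turns per-compressed-key scores into per-block scores, before the init/local
# window override. Compressed key j overlaps block b for
# j in [b * block_stride - pad_len, b * block_stride - pad_len + num_offs), with
# block_stride = block_size // kernel_stride, pad_len = kernel_size // kernel_stride - 1 and
# num_offs = kernel_size // kernel_stride + block_size // kernel_stride - 1. The histogram `offs`
# weights each offset; assuming kernel_size and block_size are multiples of kernel_stride, the
# weight counts how many kernel_stride-sized chunks the key shares with the block.
def _ref_overlap_weights(kernel_size, kernel_stride, block_size):
    counts = {}
    for a in range(kernel_size // kernel_stride):
        for b in range(block_size // kernel_stride):
            counts[a + b] = counts.get(a + b, 0) + 1
    return [counts[o] for o in range(max(counts) + 1)]


def _ref_block_score_row(score_row, kernel_size, kernel_stride, block_size, num_blocks):
    # score_row[j] is the (GQA-summed) score of compressed key j for one query
    w = _ref_overlap_weights(kernel_size, kernel_stride, block_size)
    block_stride = block_size // kernel_stride
    pad_len = kernel_size // kernel_stride - 1
    out = []
    for blk in range(num_blocks):
        total = 0.0
        for o, weight in enumerate(w):
            j = blk * block_stride - pad_len + o
            if 0 <= j < len(score_row):  # out-of-range keys contribute 0, like the masked load
                total += weight * score_row[j]
        out.append(total)
    return out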


def compressed_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int = None,
    max_seqlen_k: int = None,
    sm_scale: float = None,
    init_blocks: int = 1,
    local_blocks: int = 2,
    parallel_topk_compute: Union[str, bool] = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute attention between queries and compressed keys/values, and select the top-k key/value blocks per query for topk_sparse_attention.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of key/value blocks selected for each query. Must be at least init_blocks + local_blocks; if topk <= 0, no blocks are selected and topk_idx is returned as None.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], int32, similar to cu_seqlens_q in flash_attn_varlen_func.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], int32, similar to cu_seqlens_k in flash_attn_varlen_func.
        max_seqlen_q (int, optional): max query length in the batch. Defaults to None, which computes (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item().
        max_seqlen_k (int, optional): max key length in the batch. Defaults to None, which computes (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item().
        sm_scale (float, optional): softmax scale. Defaults to None, which uses 1/sqrt(head_dim).
        init_blocks (int, optional): number of initial blocks that are always selected for each query. Defaults to 1.
        local_blocks (int, optional): number of local blocks, ending at the query's own block, that are always selected for each query. Defaults to 2.
        parallel_topk_compute (str or bool, optional): whether to compute block scores for all kv heads in one pass. Set it to False only for very long inputs,
            where the parallel path hits a known bug that has not been fixed yet. Defaults to "auto", which uses False when the total query length (cu_seqlens_q[-1]) exceeds 32768 tokens and True otherwise.

    Returns:
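        Tuple[torch.Tensor, torch.Tensor]: attention output with the same shape as q, and int32 topk_idx of shape [num_kv_heads, total_q_len, min(topk, ceil(max_seqlen_q / block_size))] for topk_sparse_attention, where -1 marks blocks after the query's own block. topk_idx is None when topk <= 0.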
    """
    if max_seqlen_q is None:
        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
    if max_seqlen_k is None:
        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
    attn_output, lse = CompressedAttention.apply(
        q,
        k,
        v,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        sm_scale,
    )

    # do not select topk index
    if topk <= 0:
        warnings.warn("topk <= 0, returned topk_idx will be None")
        return attn_output, None

    assert topk >= init_blocks + local_blocks
    with torch.no_grad():
        num_k_heads, num_q_heads = k.shape[1], q.shape[1]
        num_shared_q_heads = num_q_heads // num_k_heads
        batch_size = cu_seqlens_q.shape[0] - 1
        q_idx = torch.cat(
            [
                torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
                for i in range(batch_size)
            ],
            dim=0,
        )
        q_idx = q_idx // block_size
        # whether to use parallel version
        if parallel_topk_compute == "auto":
            parallel_topk_compute = cu_seqlens_q[-1] <= 32768
        # parallel version
        if parallel_topk_compute:
            # recompute score
            score = _get_attention_score(
                q,
                k,
                lse,
                kernel_size,
                kernel_stride,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                sm_scale,
            )
            # transform score to block-wise score
            score = transform_score(
                score,
                kernel_size,
                kernel_stride,
                block_size,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                init_blocks,
                local_blocks,
            )
            # get topk
            topk = min(topk, score.shape[-1])
            topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
            topk_idx[topk_idx > q_idx[None, :, None]] = -1
            topk_idx = topk_idx.to(torch.int32)
        # non-parallel version: process one kv head at a time to avoid a known bug with very long sequences
        # FIXME: need to fix later
        else:
            topk_idx_list = []
            for h in range(num_k_heads):
                # recompute score
                score = _get_attention_score(
                    q[:, h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
                    k[:, h : h + 1],
                    lse[h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
                    kernel_size,
                    kernel_stride,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    max_seqlen_q,
                    max_seqlen_k,
                    sm_scale,
                )
                # transform score to block-wise score
                score = transform_score(
                    score,
                    kernel_size,
                    kernel_stride,
                    block_size,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    max_seqlen_q,
                    max_seqlen_k,
                    init_blocks,
                    local_blocks,
                )
                # get topk
                topk = min(topk, score.shape[-1])
                topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
                topk_idx[topk_idx > q_idx[None, :, None]] = -1
                topk_idx = topk_idx.to(torch.int32)
                topk_idx_list.append(topk_idx)
            topk_idx = torch.cat(topk_idx_list, dim=0)
    return attn_output, topk_idx
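

# Reference sketch (hypothetical helpers, pure Python, not part of the original module) of the
# selection step above for one query. Blocks in the init window (b < init_blocks) or the causal
# local window (b <= q_block < b + local_blocks) get an infinite score so they are always kept,
# the `topk` highest-scoring blocks are taken and sorted ascending, and any index beyond the
# query's own block becomes -1. `_ref_query_blocks` mirrors how q_idx is built from cu_seqlens_q.
def _ref_query_blocks(cu_seqlens, block_size):
    # position of every token inside its own sequence, mapped to a block index
    idx = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        idx.extend(pos // block_size for pos in range(end - start))
    return idx


def _ref_select_blocks(block_scores, q_block, topk, init_blocks=1, local_blocks=2):
    scores = list(block_scores)
    for b in range(len(scores)):
        in_local = b <= q_block < b + local_blocks
        if b < init_blocks or in_local:
            scores[b] = float("inf")
    topk = min(topk, len(scores))
    # highest scores first; the stable sort keeps the lower index on ties (torch.topk does not promise this)
    chosen = sorted(range(len(scores)), key=lambda b: -scores[b])[:topk]
    return [b if b <= q_block else -1 for b in sorted(chosen)]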


================================================
FILE: native_sparse_attention/ops/triton/flash_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional

import torch
import triton
import triton.language as tl
from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu


IS_HOPPER_GPU = is_hopper_gpu()


@triton.jit
def forward_kernel(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x h x d
    v_ptr,  # V: n x h x d
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    qk_head_dim,
    v_head_dim,
    # sm_scale
    sm_scale,
    # causal
    causal,
    # gqa
    gqa_interleave,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504  # sm_scale * log2(e), so the softmax can use exp2
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    if gqa_interleave:
        pid_kh = pid_h % NUM_KV_HEADS
    else:
        pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q/k start offsets and lengths in the packed (padding-removed) layout
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if BLOCK_SIZE_Q * pid_q >= q_len:
        return
    # init qkv pointer
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, qk_head_dim),
        strides=(stride_qn, stride_qd),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(qk_head_dim, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_K),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, v_head_dim),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # load q
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K)
    m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_VD), 0, dtype=tl.float32)
    # full attention or causal attention
    lo = 0
    if causal:
        hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q)
    else:
        hi = k_len
    for i in range(lo, hi, BLOCK_SIZE_K):
        i = tl.multiple_of(i, BLOCK_SIZE_K)
        # load k
        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        if causal:
            qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float("-inf"))
        else:
            qk += tl.where((off_k < k_len - i)[None, :], 0, float("-inf"))
        qk += tl.dot(q, k) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.math.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # scale acc_o
        acc_o_scale = tl.math.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # final scale
    acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + q_start * stride_on + pid_h * stride_oh,
        shape=(q_len, v_head_dim),
        strides=(stride_on, stride_od),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
    # save lse
    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
    tl.store(l_ptrs, lse_i, mask=off_q < q_len)


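# Reference for the forward recurrence (hypothetical pure-Python helper, one
# query row). forward_kernel works in base 2:
#   - scores are pre-multiplied by sm_scale * log2(e) (the 1.44269504
#     constant), so exp2 can replace exp;
#   - it keeps a running max m_i and a running log2-sum-exp lse_i;
#   - it rescales the accumulator block by block.
# The stored LSE is therefore in log2 units:
# lse = log2(sum_j 2 ** (s_j * sm_scale * log2(e))).
def _example_online_softmax_row(scores, values, sm_scale, block_size):
    import math

    qk_scale = sm_scale * math.log2(math.e)
    m_i = float("-inf")
    lse_i = float("-inf")
    acc = [0.0] * len(values[0])
    for start in range(0, len(scores), block_size):
        blk = [s * qk_scale for s in scores[start : start + block_size]]
        m_ij = max(m_i, max(blk))
        p = [2.0 ** (s - m_ij) for s in blk]
        l_ij = sum(p)
        # rescale the previous accumulator to the new running max
        scale = 2.0 ** (m_i - m_ij)
        acc = [a * scale for a in acc]
        for pj, v in zip(p, values[start : start + block_size]):
            acc = [a + pj * vd for a, vd in zip(acc, v)]
        # same update as the kernel: lse = m + log2(2^(lse - m) + l)
        lse_i = m_ij + math.log2(2.0 ** (lse_i - m_ij) + l_ij)
        m_i = m_ij
    # final normalisation, as in acc_o * exp2(m_i - lse_i)
    out = [a * 2.0 ** (m_i - lse_i) for a in acc]
    return out, lse_i

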
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid_n = tl.program_id(0)
    pid_h = tl.program_id(1)
    off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    off_d = tl.arange(0, BLOCK_SIZE_D)
    o = tl.load(
        o_ptr
        + off_n[:, None] * stride_on
        + pid_h * stride_oh
        + off_d[None, :] * stride_od,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    do = tl.load(
        do_ptr
        + off_n[:, None] * stride_don
        + pid_h * stride_doh
        + off_d[None, :] * stride_dod,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    tl.store(
        delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
    )

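# Causal loop bounds (hypothetical pure-Python check, illustration only).
#   - forward_kernel, for query block pid_q, iterates key blocks that start
#     below hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q). Keys at or past hi are
#     always masked for that block, so this sketch models the visited keys as
#     [0, hi).
#   - backward_dkdv, for key block pid_k, starts its query loop at
#     q_lo = pid_k * BLOCK_SIZE_K.
# This assumes queries and keys share positions (q_len == k_len), as the causal
# mask implies. Under that assumption both bounds skip only pairs with k > q,
# so every unmasked (q, k) pair is still visited.
def _example_causal_pairs(seq_len, block_size_q, block_size_k):
    fwd = set()
    for pid_q in range(-(-seq_len // block_size_q)):  # ceil division
        hi = min(seq_len, (pid_q + 1) * block_size_q)
        for q in range(pid_q * block_size_q, hi):
            for k in range(0, hi):
                fwd.add((q, k))
    bwd = set()
    for pid_k in range(-(-seq_len // block_size_k)):
        q_lo = pid_k * block_size_k
        for k in range(q_lo, min(seq_len, (pid_k + 1) * block_size_k)):
            for q in range(q_lo, seq_len):
                bwd.add((q, k))
    needed = {(q, k) for q in range(seq_len) for k in range(q + 1)}
    return needed <= fwd, needed <= bwd
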

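# Reference for backward_sum_o_do (hypothetical pure-Python helpers, one query
# row). The kernel stores delta = sum_d O[d] * dO[d] for each row and head.
# Because O = P @ V, this equals sum_j P[j] * dP[j] with dP = dO @ V^T. That is
# the row term D_i in the standard FlashAttention backward rule
# dS = P * (dP - D_i). Computing it once per row means it need not be
# recomputed inside the per-block backward loops.
def _example_delta_row(o_row, do_row):
    """delta for one row, as computed by backward_sum_o_do."""
    return sum(o * g for o, g in zip(o_row, do_row))


def _example_delta_via_probs(p_row, v_rows, do_row):
    """Same quantity via the softmax probabilities: sum_j P[j] * (dO . V[j])."""
    dp = [sum(g * x for g, x in zip(do_row, v_j)) for v_j in v_rows]
    return sum(pj * dpj for pj, dpj in zip(p_row, dp))

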
@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,  # dO: n x qh x d
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    qk_head_dim,
    v_head_dim,
    # sm_scale
    sm_scale,
    # causal
    causal,
    # gqa
    gqa_interleave,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504
    # get batch id and query head id, then map to kv head and share index
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    if gqa_interleave:
        pid_kh = pid_h % NUM_KV_HEADS
        pid_sh = pid_h // NUM_KV_HEADS
    else:
        pid_kh = pid_h // NUM_SHARE_Q_HEADS
        pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q/k start offsets and lengths in the packed (padding-removed) layout
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, qk_head_dim),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, qk_head_dim),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, v_head_dim),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, v_head_dim),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    # load k v and keep in SRAM
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_KD), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_VD), dtype=tl.float32)
    # causal
    if causal:
        q_lo = pid_k * BLOCK_SIZE_K
    else:
        q_lo = 0
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, qk_head_dim),
        strides=(stride_qn, stride_qd),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, v_head_dim),
        strides=(stride_don, stride_dod),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # loop for q blocks
    for i in range(q_lo, q_len, BLOCK_SIZE_Q):
        # load
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        if causal:
            qk = tl.where(
    
