main 9bfd8ed60359 cached
13 files
94.1 KB
23.6k tokens
89 symbols
1 requests
Download .txt
Repository: changjonathanc/flex-nano-vllm
Branch: main
Commit: 9bfd8ed60359
Files: 13
Total size: 94.1 KB

Directory structure:
gitextract_2vld2n29/

├── .gitignore
├── LICENSE
├── README.md
├── THIRD_PARTY_LICENSES.md
├── benchmark.py
├── benchmark_vllm.py
├── flex_nano_vllm/
│   ├── __init__.py
│   ├── inference.py
│   ├── modeling_gemma2.py
│   └── paged_attention.py
├── plot_metrics.py
├── pyproject.toml
└── visualize.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__/
*.egg-info/
trace_dir/
*.csv


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2025 Jonathan Chang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================
# flex-nano-vllm

FlexAttention based, minimal vllm-style inference engine for fast Gemma 2 inference.

## Introduction

This project has no flash-attn dependency, no custom triton kernel. Everything is implemented with FlexAttention. The code is commented, the structure is flat. Read the accompanying write-up: [vLLM flex attention from scratch](https://jonathanc.net/blog/vllm-flex-attention-from-scratch).

## Code Structure

```
flex-nano-vllm/
├── benchmark.py                   # Testing and benchmarking script.
├── benchmark_vllm.py              # vLLM comparison benchmark (uses uv inline dependency to run vLLM).
├── visualize.py                   # Performance visualization script.
└── flex_nano_vllm/
    ├── inference.py               # Main inference engine, uses paged attention.
    ├── modeling_gemma2.py         # Gemma2 model implementation, copied from transformers.
    └── paged_attention.py         # Paged attention implementation, including page table and paged kv cache. Based on attention-gym.
```

## Quick Start

```
uv sync

# run test and benchmark
uv run benchmark.py

# compare with vllm
uv run benchmark_vllm.py

# enable profiling to save more metrics to a csv file
# ENABLE_PROFILING=1 uv run benchmark_vllm.py
```


## Results

Test configuration:
- PyTorch version: 2.7.1+cu128
- GPU: RTX 3090 x 1 (24GB)
- Model: google/gemma-2-2b
- Workload: 512 requests, max 512 input tokens, variable output tokens (128-512)
- Configs tested: vLLM at 50% & 90% GPU memory, flex-nano-vllm with same page allocation as vLLM

| Implementation | Output Tokens/s | Request/s | Total Throughput* |
|---------------|----------------|-----------|------------------|
| vLLM v1, 90% GPU memory, high batch size† | 3,840 | 17.67 | 7,234 | 
| vLLM v1, 90% GPU memory | 3,772 | 15.26 | 6,401 | 
| flex-nano-vllm, 90% GPU memory, high batch size† | 3,440 | 14.30 | 5,817 |
| flex-nano-vllm, 90% GPU memory | 3,076 | 13.06 | 5,382 |
| vLLM v1, 50% GPU memory | 3,020 | 13.74 | 5,448 | 
| flex-nano-vllm, 50% GPU memory | 2,313 | 9.96 | 4,068 |

*Total throughput includes both input and output tokens  
† High batch size means max_num_seqs=512 in vllm (maximum allowed concurrency)

![Performance Comparison](tokens_per_second_comparison.png)

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

Third-party code incorporated in this project retains its original licenses. See [THIRD_PARTY_LICENSES.md](THIRD_PARTY_LICENSES.md) for details.

## Acknowledgments

- [GeeeekExplorer/nano-vllm](https://github.com/GeeeekExplorer/nano-vllm): this project is inspired by nano-vllm.
- [pytorch-labs/attention-gym](https://github.com/pytorch-labs/attention-gym): The paged attention implementation is based on attention-gym.
- [huggingface/transformers](https://github.com/huggingface/transformers): I copied the gemma2 model from transformers and modified it to use flex attention / paged attention.
- [vllm-project/vllm](https://github.com/vllm-project/vllm): vLLM has support for flex attention backend, which helped me find a useful flag in flex_attention.


================================================
FILE: THIRD_PARTY_LICENSES.md
================================================
# Third Party Licenses

This project incorporates code from third-party open source projects. The following licenses apply to the respective components:

## Hugging Face Transformers

This project includes modified code from the transformers project:
- **Source**: https://github.com/huggingface/transformers
- **Files**: `modeling_gemma2.py` (Gemma2 model implementation)
- **License**: Apache License 2.0
- **Copyright**: Copyright 2024 Google Inc. HuggingFace Inc. team

### Apache License 2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

## Attention Gym

This project includes modified code from the attention-gym project:
- **Source**: https://github.com/pytorch-labs/attention-gym  
- **Files**: `paged_attention.py` and related components
- **License**: BSD 3-Clause License
- **Copyright**: Copyright (c) 2023, Driss Guessous

### BSD 3-Clause License

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

================================================
FILE: benchmark.py
================================================
"""
This script is used to test the correctness and benchmark the paged attention implementation.


simplified interface of Inference class:
Usage:
llm = Inference(...)
sequences = [Sequence(text) for text in texts]
llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True, capture_cudagraph=True)
outputs = debug_print_outputs(sequences, tokenizer)




## correctness:
1. run with different 2 sequences with different lengths, the output should match huggingface .generate()
        this output has been verified to be correct outside of this script
2. run the same sequence with .generate() and paged attention, the output should match
        this is to test the paged attention Inference class does not have side effects that can alter the output across .generate() calls

### correctness with dynamic batching
3. run with different number of requests, and the same 2 sequences are mixed in the batch
        the output should match 1.

### correctness with cuda graph
4. after cudagraph capture, run some batch of requests, and the same 2 sequences are mixed in the batch
    the output should match 1.

### tests we don't cover, but might be useful to have

- PageTable unit tests
- tests with output length longer than one page (128 tokens)

"""

from transformers import AutoTokenizer
import time
from flex_nano_vllm import Gemma2ForCausalLM
from flex_nano_vllm.inference import Inference, Sequence, SamplingParams
from benchmark_vllm import generate_benchmark_data, long_prompt, short_prompt

import torch
from rich.console import Console
from torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler

console = Console()
torch.set_float32_matmul_precision("high")


def get_profiler_context():
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    profiler_context = profile(
        activities=activities,
        schedule=schedule(wait=0, warmup=10, active=10, repeat=10),
        on_trace_ready=tensorboard_trace_handler("trace_dir"),
        record_shapes=False,
        profile_memory=False,
        with_stack=True,
        with_flops=False,
    )
    return profiler_context


# long_prompt and short_prompt are now imported from bench_utils


def debug_print_outputs(sequences, tokenizer, slice=slice(None), reference=None, prefix_match=False):
    results = []
    for i, seq in enumerate(sequences[slice]):
        output_ids = seq.output_ids
        output_decoded = tokenizer.decode(output_ids, skip_special_tokens=True)
        input_decoded = tokenizer.decode(seq.input_ids, skip_special_tokens=False)

        # If reference provided, only print on mismatch
        if reference is not None and i < len(reference):
            # Check for exact match or prefix match
            if prefix_match:
                matches = output_decoded.startswith(reference[i])
            else:
                matches = output_decoded == reference[i]

            if matches:
                match_type = "prefix match" if prefix_match else "match"
                console.print(f"i={i} ✓ {match_type}", style="green", markup=False)
            else:
                mismatch_type = "PREFIX MISMATCH" if prefix_match else "MISMATCH"
                console.print(f"i={i} ✗ {mismatch_type}", style="red bold", markup=False)
                console.print(f"  expected: {reference[i][:32]}...", style="red", markup=False)
                console.print(f"  got:      {output_decoded[:32]}...", style="red", markup=False)
                console.print(f"  input:    {input_decoded}", style="dim", markup=False)
        else:
            # Normal detailed output when no reference
            console.print(f"{i=} {input_decoded=} {output_decoded[:32]=}", style="bold ", markup=False)

        results.append(output_decoded)
    return results


if __name__ == "__main__":
    # Load model and tokenizer
    model_id = "google/gemma-2-2b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    model = Gemma2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto").eval()
    # model = torch.compile(model)

    B = 8
    max_new_tokens = 8
    # settings to match vLLM (see benchmark_vllm.py)
    paged_attn_max_batch_size = 256  # match vLLM max_num_seqs=256
    max_seq_length = 2048
    max_input_length = 1024
    token_allocation = 45_360  # 50% memory usage, this number is derived by running uv run benchmark_vllm.py and checking the logs (on a 3090)
    token_allocation = 140_432  # 90% memory usage, this number is derived by running uv run benchmark_vllm.py and checking the logs (on a 3090)

    prefill_length_limit = 1024 * 8  # helps control peak memory usage for prefill

    page_size = 128
    n_pages = int(token_allocation) // page_size

    print("initializing vllm inference")
    llm = Inference(
        model,
        tokenizer,
        max_batch_size=paged_attn_max_batch_size,
        max_seq_length=max_seq_length,
        n_pages=n_pages,
        kernel_options={"BLOCK_M": 32, "BLOCK_N": 32},
        prefill_length_limit=prefill_length_limit,
    )

    print("test 1")
    ## test 1
    sequences = [Sequence(long_prompt), Sequence(short_prompt)]
    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True)
    results = debug_print_outputs(sequences, tokenizer)
    del sequences
    torch.cuda.empty_cache()

    print("test 2, same batch")
    sequences = [Sequence(long_prompt), Sequence(short_prompt)]
    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True)
    results2 = debug_print_outputs(sequences, tokenizer, reference=results)
    for i in range(len(results)):
        assert results[i] == results2[i], f"{i=}, {results[i]=}, {results2[i]=}"
    del sequences
    torch.cuda.empty_cache()

    print("test 2.1: reverse order")
    sequences = [Sequence(short_prompt), Sequence(long_prompt)]
    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True)
    # Debug in the right order: [long_prompt, short_prompt] to match reference
    reordered_sequences = [sequences[1], sequences[0]]  # [long_prompt, short_prompt]
    results21 = debug_print_outputs(reordered_sequences, tokenizer, reference=results)
    for i in range(len(results)):
        assert results[i] == results21[i], f"{i=}, {results[i]=}, {results21[i]=}"
    del sequences
    torch.cuda.empty_cache()

    print("test 3: batch with other sequence")
    sequences = [Sequence(short_prompt), Sequence(long_prompt), Sequence("hi")]
    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True)
    # Debug just the sequences we care about in the right order: [long_prompt, short_prompt]
    comparison_sequences = [sequences[1], sequences[0]]  # [long_prompt, short_prompt]
    results3 = debug_print_outputs(comparison_sequences, tokenizer, reference=results)
    del sequences
    for i in range(len(results)):
        assert results[i] == results3[i], f"{i=}, {results[i]=}, {results3[i]=}"
    torch.cuda.empty_cache()

    print("test 4: batch with other sequence, capture cudagraph")
    sequences = [
        Sequence("this is a test messaage hello "),
        Sequence(short_prompt),
        Sequence("test"),
        Sequence(long_prompt),
        Sequence("hello world "),
    ]
    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True, capture_cudagraph=True)
    # Debug just the sequences we care about in the right order: [long_prompt, short_prompt]
    comparison_sequences = [sequences[3], sequences[1]]  # [long_prompt, short_prompt]
    results4 = debug_print_outputs(comparison_sequences, tokenizer, reference=results)
    del sequences
    for i in range(len(results)):
        assert results[i] == results4[i], f"{i=}, {results[i]=}, {results4[i]=}"
    torch.cuda.empty_cache()

    # Generate test batch for cudagraph
    test_requests = generate_benchmark_data(tokenizer, n_requests=4, max_input_length=max_input_length)
    sequences = [Sequence(req.text) for req in test_requests]
    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=16), use_tqdm=True, capture_cudagraph=True)
    # Just capture cudagraph, no need to show verbose output
    # debug_print_outputs(sequences, tokenizer)

    print("after cudagraph")
    test_requests2 = generate_benchmark_data(tokenizer, n_requests=4, max_input_length=max_input_length)
    sequences = [Sequence(req.text) for req in test_requests2] + [Sequence(long_prompt), Sequence(short_prompt)]
    print("replay cudagraph, & profile")
    with get_profiler_context() as prof:
        llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=16), use_tqdm=True, profiler=prof)
    results_prefill = debug_print_outputs(sequences, tokenizer, slice=slice(-2, None), reference=results, prefix_match=True)
    for i in range(len(results_prefill)):
        assert results_prefill[i][: len(results[i])] == results[i], f"{i=}, {results[i]=}, {results_prefill[i][:len(results[i])]=}"
    del sequences
    torch.cuda.empty_cache()

    ## benchmark throughput
    n_requests = 512
    max_input_length = 512
    # Use shared benchmark data generation
    benchmark_requests = generate_benchmark_data(
        tokenizer,
        n_requests=n_requests,
        max_input_length=max_input_length,
    )

    # Convert to flex-nano-vllm format
    sequences = [Sequence(req.text) for req in benchmark_requests]
    sampling_params = [SamplingParams(max_new_tokens=req.max_new_tokens) for req in benchmark_requests]

    print("\n--- RUNNING BENCHMARK ---")

    # Reset memory stats to track benchmark-specific usage
    torch.cuda.reset_peak_memory_stats()

    start_time = time.time()
    llm.generate(sequences, sampling_params=sampling_params, use_tqdm=False, save_metrics_csv=True, print_stats=True)
    total_time = time.time() - start_time

    total_output_length = sum(seq.output_length for seq in sequences)
    total_input_length = sum(len(seq.input_ids) for seq in sequences)

    # Get memory usage
    peak_memory_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    current_memory_mb = torch.cuda.memory_allocated() / 1024 / 1024

    print("\n--- PERFORMANCE METRICS ---")
    print(f"Total time: {total_time:.2f}s")
    print(f"Throughput: {total_output_length / total_time:.1f} tokens/s")
    print(f"Request throughput: {len(sequences) / total_time:.2f} req/s")
    print(f"Total throughput (prompt+new): {(total_input_length + total_output_length) / total_time:.1f} tokens/s")
    print(f"Peak memory: {peak_memory_mb:.1f} MB")
    print(f"Current memory: {current_memory_mb:.1f} MB")

    print("\nafter benchmark")
    results_final = debug_print_outputs(sequences, tokenizer, slice=slice(n_requests, n_requests + 2), reference=results, prefix_match=True)

    # Verify correctness
    for i in range(len(results_final)):
        assert results[i] == results_final[i][: len(results[i])], f"{i=}, {results[i]=}, {results_final[i][:len(results[i])]=}"


================================================
FILE: benchmark_vllm.py
================================================
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "vllm",
#     "transformers",
#     "datasets",
#     "matplotlib",
#     "tqdm",
#     "pandas",
# ]
# ///

# Usage:
# Run benchmark with minimal overhead: python bench_utils.py
# Run benchmark with profiling: ENABLE_PROFILING=1 python bench_utils.py
# NOTE: this script serves 2 purposes:
# 1. it can be used to benchmark vLLM's performance, run with isolated inline dependencies.
# 2. outside of __main__, it contains utils for producing the same payload for benchmarking.

from tqdm import tqdm
from datasets import load_dataset
import random
from dataclasses import dataclass

@dataclass
class BenchmarkRequest:
    """Simple, framework-agnostic request data"""
    text: str
    max_new_tokens: int

# Standard prompts used in benchmarks
long_prompt = """
The 12 months of the year are: January, February, March,
""".strip()

short_prompt = "The first 20 prime numbers are: 2, 3,"


def generate_benchmark_data(tokenizer, n_requests=512, max_input_length=512, min_tokens=128, max_tokens=512):
    """Generate benchmark data by skipping prompts that are too long."""
    from datasets import load_dataset
    
    data = load_dataset("Open-Orca/OpenOrca")["train"]
    benchmark_requests = []
    
    attempts = 0
    while len(benchmark_requests) < n_requests:
        # Deterministic sampling using hash
        idx = hash(f"req_{attempts}") % len(data)
        
        system_prompt = data[idx]["system_prompt"] or ""
        question = data[idx]["question"] or ""
        prompt = f"{idx}: {system_prompt} {question}".strip()
        
        # Check length and skip if too long
        tokens = tokenizer.encode(prompt)
        if len(tokens) <= max_input_length:
            prompt_hash = hash(prompt)
            max_new_tokens = min_tokens + (abs(prompt_hash >> 16) % (max_tokens - min_tokens + 1))
            benchmark_requests.append(BenchmarkRequest(text=prompt, max_new_tokens=max_new_tokens))
        
        attempts += 1
        if attempts > n_requests * 10:  # Safety valve to prevent infinite loop
            break
    
    # Add standard prompts
    benchmark_requests.extend([
        BenchmarkRequest(text=long_prompt, max_new_tokens=max_tokens),
        BenchmarkRequest(text=short_prompt, max_new_tokens=max_tokens)
    ])
    
    return benchmark_requests


def print_step_stats(steps, name):
    """Helper to print timing statistics for a collection of steps."""
    if not steps:
        return
    print(f"\n{name}:")
    print(f"  Count: {len(steps)}")
    print(f"  Total: {sum(steps):.4f}s")
    print(f"  Mean:  {sum(steps)/len(steps):.4f}s")
    print(f"  Min:   {min(steps):.4f}s")
    print(f"  Max:   {max(steps):.4f}s")


def generate_with_timing(llm, sequences, sampling_params, collect_detailed_metrics=False):
    """
    Generate with timing, optionally collecting detailed metrics.
    
    Note: We track total step time rather than trying to separate prefill/decode
    because vLLM can do both types of work within a single step, making such
    separation misleading for performance analysis.
    """
    outputs = []
    total_step_time = 0.0
    step_times = []
    
    # Optional detailed metrics
    metrics_data = {} if not collect_detailed_metrics else {
        'steps': [], 'requests_running': [], 'requests_waiting': [], 'preemptions': []
    }
    
    # Add requests
    for i, prompt in enumerate(sequences):
        sp = sampling_params[i] if isinstance(sampling_params, list) else sampling_params
        llm.llm_engine.add_request(str(i), prompt, sp)
    
    step_count = 0
    
    while llm.llm_engine.has_unfinished_requests():
        step_start = time.perf_counter()
        step_outputs = llm.llm_engine.step()
        step_duration = time.perf_counter() - step_start
        step_count += 1
        
        total_step_time += step_duration
        step_times.append(step_duration)
        
        # Collect outputs
        for output in step_outputs:
            if output.finished:
                outputs.append(output)
                
        # Optional detailed metrics
        if collect_detailed_metrics:
            metrics = llm.llm_engine.get_metrics()
            metrics_data['steps'].append(step_count)
            # Extract key metrics
            running = waiting = preemptions = 0
            for metric in metrics:
                if "requests_running" in metric.name:
                    running = metric.value
                elif "requests_waiting" in metric.name:
                    waiting = metric.value
                elif "preemptions" in metric.name:
                    preemptions = metric.value
            metrics_data['requests_running'].append(running)
            metrics_data['requests_waiting'].append(waiting)
            metrics_data['preemptions'].append(preemptions)
    
    return total_step_time, step_times, outputs, metrics_data


if __name__ == "__main__":
    import os
    import time
    
    os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
    os.environ['VLLM_USE_V1'] = '1'
    
    ENABLE_PROFILING = os.environ.get("ENABLE_PROFILING", "0") == "1"
    
    from vllm import LLM, SamplingParams as VLLMSamplingParams
    from transformers import AutoTokenizer

    MODEL_ID = "google/gemma-2-2b"
    llm = LLM(MODEL_ID, dtype="bfloat16", gpu_memory_utilization=0.9, max_num_seqs=256, max_model_len=2048, disable_log_stats=False)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    
    n_requests = 512
    max_input_length = 512

    benchmark_requests = generate_benchmark_data(tokenizer, n_requests, max_input_length)
    
    sequences = [req.text for req in benchmark_requests]
    sampling_params_list = [
        VLLMSamplingParams(temperature=0.0, top_p=1.0, max_tokens=req.max_new_tokens)
        for req in benchmark_requests
    ]
    
    # Warmup
    llm.generate(["warmup"], VLLMSamplingParams(max_tokens=1), use_tqdm=False)



    print(f"\n--- RUNNING {'WITH' if ENABLE_PROFILING else 'WITHOUT'} DETAILED METRICS ---")
    
    # Reset memory stats to track benchmark-specific usage  
    start_time = time.time()
    total_step_time, step_times, outputs, metrics_data = generate_with_timing(
        llm, sequences, sampling_params_list, collect_detailed_metrics=ENABLE_PROFILING
    )
    total_time = time.time() - start_time

    total_output_length = sum(len(o.outputs[0].token_ids) for o in outputs)
    prompt_tok = sum(len(o.prompt_token_ids) for o in outputs)
    
    print_step_stats(step_times, "\nstep")
    
    # Get memory usage
    print("\n--- PERFORMANCE METRICS ---")
    print(f"Total time: {total_time:.2f}s")
    print(f"Throughput: {total_output_length / total_time:.1f} tokens/s")
    print(f"Request throughput: {len(sequences) / total_time:.2f} req/s")
    print(f"Total throughput (prompt+new): {(prompt_tok + total_output_length) / total_time:.1f} tokens/s")
        
    if ENABLE_PROFILING and metrics_data:
        print("\n--- DETAILED METRICS ---")
        import pandas as pd
        pd.DataFrame(metrics_data).to_csv('vllm_metrics.csv', index=False)
        print("Metrics saved to 'vllm_metrics.csv'")



================================================
FILE: flex_nano_vllm/__init__.py
================================================
"""flex-nano-vllm - Flex-attention based nano-vllm implementation for fast PaliGemma inference."""

from .modeling_gemma2 import Gemma2ForCausalLM
from .inference import Inference, Sequence

__version__ = "0.1.0"
__all__ = ["Gemma2ForCausalLM", "Inference", "Sequence"]


================================================
FILE: flex_nano_vllm/inference.py
================================================
from collections import deque
import time
import torch
from tqdm import tqdm
from torch.nn.attention.flex_attention import BlockMask
import torch.nn.attention.flex_attention
import torch.nn.functional as F
from rich.console import Console

from flex_nano_vllm.paged_attention import PageTable, PagedKVCache

from dataclasses import dataclass

console = Console()
print(f"torch version: {torch.__version__}")


@dataclass
class SamplingParams:
    max_new_tokens: int = -1


def sample(logits_BV, greedy=True, to_cpu=False):
    # NOTE: use greedy=True to ensure deterministic sampling
    assert logits_BV.ndim == 2
    B, V = logits_BV.shape
    probs = torch.softmax(logits_BV, dim=-1)
    if not greedy:
        indices = torch.multinomial(probs, num_samples=1)  # shape: [B, 1]
        logits = torch.gather(logits_BV, dim=-1, index=indices)
        probs = torch.gather(probs, dim=-1, index=indices)
    else:
        probs, indices = torch.topk(probs, k=1, dim=-1)
        logits = torch.gather(logits_BV, dim=-1, index=indices)
    if to_cpu:
        indices = indices.to("cpu", non_blocking=True).view(B)
        logits = logits.to("cpu", non_blocking=True).view(B)
        probs = probs.to("cpu", non_blocking=True).view(B)
        torch.cuda.synchronize()
    return indices.tolist(), logits.tolist(), probs.tolist()


class Sequence:
    def __init__(self, text: str):
        self.done = False
        self.text = text
        self._output_ids = []
        self._output_logits = []
        self._output_probs = []
        self.input_ids = []
        self.finished = False

        self.input_length = None
        self.inputs = None

    def add_next_token(self, token_id: torch.Tensor, logits: torch.Tensor, probs: torch.Tensor):
        #assert token_id.ndim == 0
        #assert logits.ndim == 0
        self._output_ids.append(token_id)
        self._output_logits.append(logits)
        self._output_probs.append(probs)

    def copy(self):
        return Sequence(self.text)

    @property
    def output_ids(self):
        return torch.tensor(self._output_ids, dtype=torch.int64)

    @property
    def output_logits(self):
        return torch.tensor(self._output_logits, dtype=torch.float32)

    @property
    def output_probs(self):
        return torch.tensor(self._output_probs, dtype=torch.float32)

    @property
    def output_length(self):
        return len(self._output_ids)

    @property
    def total_length(self):
        return self.input_length + self.output_length

    @property
    def total_token_ids(self):
        if self.output_length:
            return torch.cat([self.input_ids, self.output_ids], dim=0)
        return self.input_ids

    @property
    def last_token_id(self):
        return self._output_ids[-1]


def process_sampling_params(sequences: list[Sequence], sampling_params: SamplingParams | list[SamplingParams] | None):
    if sampling_params is None:
        sampling_params = SamplingParams()
    if isinstance(sampling_params, SamplingParams):
        sampling_params = [sampling_params] * len(sequences)

    assert len(sampling_params) == len(sequences), "sampling_params must be a list of the same length as sequences"

    for seq, param in zip(sequences, sampling_params):
        seq.params = param


class Inference:
    def __init__(self, model, tokenizer, max_batch_size, max_seq_length, n_pages, page_size=128, prefill_length_limit=-1, kernel_options=None):
        self.page_table = PageTable(n_pages=n_pages, page_size=page_size, max_batch_size=max_batch_size)

        self.model = model
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer.eos_token_id # cache this because it's not efficient to call tokenizer.eos_token_id every time
        self.device = model.device
        assert max_seq_length % page_size == 0, "max_seq_length must be divisible by page_size"
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        self.kernel_options = kernel_options
        self.prefill_length_limit = prefill_length_limit  # NOTE: control the peak memory usage of prefill

        for layer in self.model.model.layers:
            layer.self_attn.kv_cache = PagedKVCache(
                self.page_table,
                n_heads=self.model.model.config.num_key_value_heads,
                head_dim=self.model.model.config.head_dim,
                dtype=self.model.dtype,
            ).to(self.device, non_blocking=True)

        self.cudagraph_captured = False

        self.input_pos = torch.zeros(self.max_batch_size, dtype=torch.int32, pin_memory=True).to(self.device, non_blocking=True)
        self.block_mask = self.page_table.create_causal_blockmask(B=self.max_batch_size, L=self.max_seq_length)

    def _prefill_sequences(
        self, input_ids: torch.Tensor, input_pos: torch.Tensor, batch_idx_tensor: torch.Tensor, logits_to_keep: torch.Tensor
    ) -> torch.Tensor:
        # 1. no cuda graph
        # 2. construct block mask and apply it in logical space
        # 3. only write to kv cache, no read

        # NOTE: for batch/packed prefill, we need to pass batch_idx_tensor as [1, L]
        # input_ids is [1, L], concatenated from all sequences
        # batch_idx_tensor is [1, L]
        # position_ids is [1, L]
        # logits_to_keep is [num_sequences] instead of [1]

        ## padding: if there's padding
        # input_ids should be padded with any valid token id
        # input_pos should be padded with 0
        # batch_idx_tensor should be padded with 0 # reserved in page table

        assert input_ids.shape[0] == 1, "input_ids must be [1, L]"
        assert input_pos.shape == input_ids.shape, f"input_pos must be [1, L], got {input_pos.shape=}, {input_ids.shape=}"
        assert batch_idx_tensor.shape == input_ids.shape, f"batch_idx_tensor must be [1, L], got {batch_idx_tensor.shape=}, {input_ids.shape=}"

        mask = self.page_table.create_prefill_blockmask_no_paging(batch_idx_tensor)
        outputs = self.model.model(
            input_ids=input_ids,
            position_ids=input_pos + 1,  # NOTE: gemma2 uses 1-based position ids
            # logits_to_keep=logits_to_keep,
            flex_attn_block_mask=mask,
            flex_attn_input_pos=input_pos,
            flex_attn_batch_idx=batch_idx_tensor,
            flex_attn_kernel_options=self.kernel_options
            | {"FORCE_USE_FLEX_ATTENTION": True},  # NOTE: force torch compile to not use flash decoding code path
        )
        return self.model.lm_head(outputs.last_hidden_state[:, logits_to_keep, :])

        """
        outputs = self.model(
            input_ids=input_ids,
            position_ids=input_pos + 1, # NOTE: gemma2 uses 1-based position ids
            logits_to_keep=logits_to_keep,
            flex_attn_block_mask=mask,
            flex_attn_input_pos=input_pos,
            flex_attn_batch_idx=batch_idx_tensor,
            flex_attn_kernel_options=self.kernel_options | {'FORCE_USE_FLEX_ATTENTION': True}, # NOTE: force torch compile to not use flash decoding code path
        )
        return outputs.logits
        """

    def prefill_sequences(self, sequences: list[Sequence]) -> torch.Tensor:
        input_ids = torch.cat([seq.total_token_ids for seq in sequences], dim=0)
        input_pos = torch.cat([torch.arange(seq.total_length, dtype=torch.long) for seq in sequences], dim=0)
        batch_idx_tensor = torch.cat([torch.ones(seq.total_length, dtype=torch.long) * seq.batch_idx for seq in sequences], dim=0)
        input_lengths = torch.tensor([seq.total_length for seq in sequences], dtype=torch.int32).to(self.device, non_blocking=True)
        logits_to_keep = input_lengths.cumsum(dim=0) - 1

        num_pad = 128 - input_ids.shape[0] % 128
        if num_pad > 0:
            input_ids = F.pad(input_ids.view(-1), (0, num_pad), mode="constant", value=0)
            input_pos = F.pad(input_pos.view(-1), (0, num_pad), mode="constant", value=0)
            batch_idx_tensor = F.pad(batch_idx_tensor.view(-1), (0, num_pad), mode="constant", value=0)
            # logits_to_keep is not padded, it should have shape [num_sequences]

        input_ids = input_ids.view(1, -1).to(self.device, non_blocking=True)
        input_pos = input_pos.view(1, -1).to(self.device, non_blocking=True)
        batch_idx_tensor = batch_idx_tensor.view(1, -1).to(self.device, non_blocking=True)
        logits_to_keep = logits_to_keep.view(-1).to(self.device, non_blocking=True)

        logits = self._prefill_sequences(input_ids, input_pos, batch_idx_tensor, logits_to_keep)
        return logits

    def get_decoding_block_mask(self, batch_idx: torch.Tensor):
        """
        Args:
            batch_idx: [B]
        Returns:
            block_mask: [B, H, ROWS=1, MAX_BLOCKS_IN_COL]
            input_pos: [B]

        This function slices the
            full block mask self.block_mask:  [max_batch_size, H, MAX_BLOCKS_IN_ROW, MAX_BLOCKS_IN_COL]
            using self.input_pos: [max_batch_size]
            and batch_idx: [B]
        """

        # NOTE: this function is entirely in logical space
        def causal_offset(off: torch.Tensor):
            def offset(b, h, q_idx, kv_idx):
                return q_idx + off[b] >= kv_idx

            return offset

        block_mask = self.block_mask
        input_pos = self.input_pos[batch_idx]
        # batch_idx: [B], input_pos: [B]
        assert batch_idx.ndim == 1, "batch_idx must be 1D"
        assert input_pos.ndim == 1, "input_pos must be 1D"
        (B,) = batch_idx.shape
        input_block_idx = input_pos // block_mask.BLOCK_SIZE[0]  # [B]
        kv_num_blocks = block_mask.kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)
        kv_indices = block_mask.kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)
        full_kv_num_blocks, full_kv_indices = None, None
        if block_mask.full_kv_num_blocks is not None:
            full_kv_num_blocks = block_mask.full_kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)  # noqa
            full_kv_indices = block_mask.full_kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)  # noqa
        seq_length = (1, block_mask.seq_lengths[1])
        mask = BlockMask.from_kv_blocks(
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            BLOCK_SIZE=block_mask.BLOCK_SIZE,
            mask_mod=causal_offset(input_pos),
            seq_lengths=seq_length,
        )
        return mask, input_pos

    def _decode_step(self, batch_idx: torch.Tensor, input_ids: torch.Tensor):
        B = input_ids.shape[0]
        mask, input_pos = self.get_decoding_block_mask(batch_idx)
        mask = self.page_table.convert_logical_block_mask(mask, batch_idx)
        outputs = self.model(
            input_ids=input_ids.view(B, 1),
            position_ids=(input_pos + 1).view(B, 1),  # NOTE: position_ids is needed for decoding. For Gemma2, it's 1-based
            flex_attn_block_mask=mask,
            flex_attn_input_pos=input_pos.view(B, 1),
            flex_attn_batch_idx=batch_idx.view(-1),
            flex_attn_kernel_options=self.kernel_options,
        )
        return outputs.logits

    def decode_step(self, batch_idx: torch.Tensor, input_ids: torch.Tensor, input_pos: torch.Tensor):
        assert input_ids.ndim == 1, "input_ids must be 1D"
        assert batch_idx.ndim == 1, "batch_idx must be 1D"
        assert input_ids.shape[0] == batch_idx.shape[0], "input_ids and batch_idx must have the same length"
        self.input_pos.zero_()
        self.input_pos[batch_idx] = input_pos
        if not self.cudagraph_captured:
            return self._decode_step(batch_idx, input_ids)
        else:
            bs = input_ids.size(0)
            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
            graph_vars = self.graph_vars
            for k, v in graph_vars.items():
                if k != "outputs":
                    v.zero_()  # batch idx is 0 for undefined batch idx, so we never overwrite any kv-cache
            graph_vars["input_ids"][:bs] = input_ids
            graph_vars["batch_idx"][:bs] = batch_idx
            graph.replay()
            replayed = graph_vars["outputs"][:bs]
            return replayed

    def _check_done(self, sequences: list[Sequence]):
        for seq in sequences:
            is_eos = seq.last_token_id == self.eos_token_id
            is_max_len = seq.input_length + seq.output_length >= self.max_seq_length
            is_max_new = seq.output_length == seq.params.max_new_tokens
            if is_eos or is_max_len or is_max_new:
                seq.finished = True
                self.done.append(seq)
                self.page_table.erase(seq.batch_idx)
        return [seq for seq in sequences if not seq.finished]

    def run_one_step(self):
        # try to prefill a new sequence, if we can schedule it
        batch = []
        prefill_length_sum = 0
        # NOTE: for each sequence, only allocate & reserve just enough pages for the current sequence length
        # We do not reserve pages for the future tokens. This allows maximizing the batch size of decoding.
        # We support preemption in case we run out of pages during decoding
        while (
            self.waiting
            and self.page_table.can_reserve(self.waiting[0].total_length)  # + self.waiting[0].params.max_new_tokens)
            and (not batch or self.prefill_length_limit == -1 or prefill_length_sum + self.waiting[0].total_length < self.prefill_length_limit)
        ):
            prefill_length_sum += self.waiting[0].total_length

            seq = self.waiting.popleft()
            batch_idx_int = self.page_table.allocate()
            batch_idx_tensor = torch.tensor([batch_idx_int], device=self.device, dtype=torch.long)
            self.page_table.reserve(batch_idx_int=batch_idx_int, batch_idx=batch_idx_tensor, seq_len=seq.total_length)  # + seq.params.max_new_tokens
            seq.batch_idx = batch_idx_int
            batch.append(seq)
        if batch:
            logits_1LV = self.prefill_sequences(batch)
            next_token, logits, probs = sample(logits_1LV[0, :, :], to_cpu=True)
            for b in range(len(batch)):
                batch[b].add_next_token(next_token[b], logits[b], probs[b])
            self.running.extend(self._check_done(batch))
            return "prefill"
        # reserve new block for running sequences if needed
        # if a new block is needed, but there's no space, preempt the newest sequence
        # running: [oldest, ... , newest]  waiting: [oldest, ... , newest]
        while self.running:
            seq = self.running.popleft()
            if self.page_table.capacity[seq.batch_idx] >= seq.total_length:
                # no need to reserve new pages
                batch.append(seq)
            elif self.page_table.can_reserve(seq.total_length, batch_idx_int=seq.batch_idx):
                # reserve new pages
                self.page_table.reserve(
                    batch_idx_int=seq.batch_idx,
                    batch_idx=torch.tensor([seq.batch_idx], device=self.device, dtype=torch.long),
                    seq_len=seq.total_length,
                )
                batch.append(seq)
            else:
                # no space to run this sequence, preempt the newest sequence
                self.running.appendleft(seq)  # first put this sequence back
                newest = self.running.pop()  # then pop the newest sequence
                self.waiting.appendleft(newest)
                self.page_table.erase(newest.batch_idx)

        B = len(batch)
        # now we do decoding
        batch_idx = torch.tensor([seq.batch_idx for seq in batch], dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)
        input_ids = torch.tensor([seq.last_token_id for seq in batch], dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)
        input_pos = torch.tensor([seq.total_length - 1 for seq in batch], dtype=torch.int32, pin_memory=True).to(self.device, non_blocking=True)
        self.counts.append(B)
        logits_BLV = self.decode_step(batch_idx, input_ids, input_pos)
        next_token, logits, probs = sample(logits_BLV[:, -1, :], to_cpu=True)

        for i in range(B):
            batch[i].add_next_token(next_token[i], logits[i], probs[i])
        self.running = deque(self._check_done(batch))
        return "decode"

    def tokenize(self, sequences: list[Sequence]):
        self.tokenizer.padding_side = "right"
        for seq in sequences:
            seq.input_ids = self.tokenizer([seq.text], return_tensors="pt")["input_ids"].squeeze(0)
            seq.input_length = seq.input_ids.shape[0]

    @torch.inference_mode()
    def generate(
        self,
        sequences: list[Sequence],
        use_tqdm=False,
        profiler=None,
        greedy=False,
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        capture_cudagraph=False,
        save_metrics_csv=False,
        print_stats=False,
    ):
        self.counts = []
        self.metrics_data = {"step": [], "requests_running": [], "requests_waiting": [], "step_type": []}
        # preprocess the sequences
        self.tokenize(sequences)
        self.waiting = deque(sequences)
        self.running = deque()
        self.done = deque()
        process_sampling_params(sequences, sampling_params)

        if capture_cudagraph and not self.cudagraph_captured:
            self.capture_decode_cudagraph()
            self.cudagraph_captured = True

        total_sequences = len(self.waiting)
        times = []
        step_count = 0
        with tqdm(total=total_sequences, disable=not use_tqdm, desc="Generating") as pbar:
            prev_done = 0
            while self.waiting or self.running:
                step_count += 1
                # Track metrics before step
                self.metrics_data["step"].append(step_count)
                self.metrics_data["requests_running"].append(len(self.running))
                self.metrics_data["requests_waiting"].append(len(self.waiting))

                time_start = time.perf_counter()
                step_type = self.run_one_step()
                time_end = time.perf_counter()

                # Track step type
                self.metrics_data["step_type"].append(step_type)
                times.append({"step_type": step_type, "time": time_end - time_start})
                if profiler:
                    profiler.step()
                # Update progress bar based on newly completed sequences
                curr_done = len(self.done)
                if curr_done > prev_done:
                    pbar.update(curr_done - prev_done)
                    prev_done = curr_done
        if print_stats:
            self.print_time_stats(times)

        # Save metrics to CSV if requested
        if save_metrics_csv:
            import pandas as pd

            df = pd.DataFrame(self.metrics_data)
            df.to_csv("flex_nano_vllm_metrics.csv", index=False)
            print("Metrics saved to 'flex_nano_vllm_metrics.csv'")

    def capture_decode_cudagraph(self):
        """
        capture cudagraph for decoding
        """
        max_bs = self.max_batch_size
        input_ids = torch.zeros(max_bs, dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)
        batch_idx = torch.arange(max_bs, dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)
        # NOTE: here we use logits as the final output, but we can consider using last hidden state as the output
        outputs = torch.zeros((max_bs, 1, self.model.model.config.vocab_size), pin_memory=True).to(self.device, non_blocking=True)
        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))  # 8 vs 16 doesn't make much difference
        self.graphs = {}
        self.graph_pool = None

        for bs in reversed(self.graph_bs):
            print(f"capturing cudagraph for {bs} sequences")
            torch.cuda.synchronize()
            graph = torch.cuda.CUDAGraph()
            # warmup
            outputs[:bs] = self._decode_step(batch_idx[:bs], input_ids[:bs])  # warmup
            # capture
            with torch.cuda.graph(graph, self.graph_pool):
                outputs[:bs] = self._decode_step(batch_idx[:bs], input_ids[:bs])  # capture
            if self.graph_pool is None:
                self.graph_pool = graph.pool()
            self.graphs[bs] = graph
            torch.cuda.synchronize()

        self.graph_vars = dict(
            input_ids=input_ids,
            batch_idx=batch_idx,
            outputs=outputs,
        )
        # in our code, the page table tensors are modified in-place, so we don't need to put them in graph vars

    def print_time_stats(self, times):
        stats = {}
        for step in ["decode", "prefill"]:
            step_times = [t["time"] for t in times if t["step_type"] == step]
            stats[step] = {
                "count": len(step_times),
                "total": sum(step_times),
                "mean": sum(step_times) / len(step_times) if step_times else 0,
                "min": min(step_times) if step_times else 0,
                "max": max(step_times) if step_times else 0,
            }

        print("\nTime statistics by step type:")
        for step, metrics in stats.items():
            print(f"\n{step}:")
            print(f"  Count: {metrics['count']}")
            print(f"  Total: {metrics['total']:.4f}s")
            print(f"  Mean:  {metrics['mean']:.4f}s")
            print(f"  Min:   {metrics['min']:.4f}s")
            print(f"  Max:   {metrics['max']:.4f}s")
        print(f"\nTotal time: {sum(t['time'] for t in times):.4f}s")


================================================
FILE: flex_nano_vllm/modeling_gemma2.py
================================================
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_gemma2.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union

import torch
import torch.nn as nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config

from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention, fullgraph=True)

logger = logging.get_logger(__name__)


class Gemma2RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

class Gemma2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    if scaling is None:
        scaling = module.head_dim**-0.5

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    if softcap is not None:
        attn_weights = attn_weights / softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * softcap
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


class Gemma2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
        self.kv_cache = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        flex_attn_block_mask = None,
        flex_attn_input_pos = None,
        flex_attn_batch_idx = None,
        flex_attn_kernel_options = {'BLOCK_M': 32, 'BLOCK_N': 32, 'BLOCK_M1': 32, 'BLOCK_M2': 32, 'BLOCK_N1': 32, 'BLOCK_N2': 32, },
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # NOTE: this does not cover the sliding window case, but in my current usage, the sequence length does not exceed 4096
        def soft_cap(score, b, h, q_idx, kv_idx):
            score = score / self.attn_logit_softcapping
            score = torch.tanh(score)
            score = score * self.attn_logit_softcapping
            return score

        if self.kv_cache is not None and flex_attn_input_pos is not None:
            key_states, value_states= self.kv_cache.update(flex_attn_input_pos, key_states, value_states, flex_attn_batch_idx)


        attn_output = flex_attention(
            query_states,
            key_states,
            value_states,
            #dropout=self.attention_dropout if self.training else 0.0,
            scale=self.scaling,
            block_mask=flex_attn_block_mask,
            score_mod=soft_cap,
            enable_gqa=True,
            kernel_options=flex_attn_kernel_options,
        )
        attn_weights = None
        attn_output = attn_output.transpose(1, 2) # (B, H, N, E) -> (B, N, H, E)

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Gemma2DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma2MLP(config)
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("last_cache_position", version="4.53.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class Gemma2RotaryEmbedding(nn.Module):
    def __init__(self, config: Gemma2Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@auto_docstring
class Gemma2PreTrainedModel(PreTrainedModel):
    config: Gemma2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _supports_static_cache = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Gemma2DecoderLayer,
        "attentions": Gemma2Attention,
    }

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Gemma2RMSNorm):
            module.weight.data.fill_(1.0)


@auto_docstring
class Gemma2Model(Gemma2PreTrainedModel):
    def __init__(self, config: Gemma2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # embed positions
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=None,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


@auto_docstring
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

__all__ = [
    "Gemma2ForCausalLM",
    "Gemma2Model",
    "Gemma2PreTrainedModel",
]


================================================
FILE: flex_nano_vllm/paged_attention.py
================================================
# Adapted from attention-gym
# Original source: https://github.com/pytorch-labs/attention-gym
# License: BSD 3-Clause (see THIRD_PARTY_LICENSES.md)
# Copyright (c) 2023, Driss Guessous

# the original implementation has some bugs and has some feature that lives outside of the PageTable class

from typing import Optional
import torch
from torch import Tensor
from torch.nn.attention.flex_attention import (
    _identity,
    _mask_mod_signature,
    _score_mod_signature,
    BlockMask,
    noop_mask,
    create_block_mask,
)

create_block_mask = torch.compile(create_block_mask)


def _cdiv(x: int | float | torch.Tensor, multiple: int | float | torch.Tensor):
    return (x + multiple - 1) // multiple


class PagedKVCache(torch.nn.Module):
    def __init__(self, page_table, n_heads, head_dim, dtype):
        super().__init__()
        cache_shape = (1, n_heads, page_table.n_pages * page_table.page_size, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

        self.page_table = page_table

    def update(self, input_pos, k_val, v_val, batch_idx=None):
        assert batch_idx is not None, "batch_idx is required for paged kv cache, are you using non-paged attention?"

        if batch_idx.ndim == 1:
            # batch_idx should be [B] (decode)
            return self.page_table.assign(batch_idx, input_pos, k_val, v_val, self.k_cache, self.v_cache)
        else:
            assert batch_idx.ndim == 2, "batch_idx must be 1D or 2D"
            # batch_idx should be [1, L] (batch prefill)
            return self.page_table.assign_prefill_no_paging(batch_idx, input_pos, k_val, v_val, self.k_cache, self.v_cache)


class PageTable:
    """
    PageTable is a modified version of PagedAttention from attention-gym.

    PageTable improves it by:
    - maintaining a cpu copy of the page table, to avoid device-to-host transfers
    - support batch prefill
    - fix the bug in the original code in mask_mod and score_mod by mapping physical batch index to logical batch index
    - subsuming the free_batch_idx into the page table, so we don't need to maintain it separately
    """

    def __init__(
        self,
        n_pages: int,
        page_size: int,
        max_batch_size: int,
        device: str = "cuda",
    ):
        self.n_pages = n_pages
        self.page_size = page_size
        self.max_batch_size = max_batch_size
        self.device = device

        # page table: [logical_batch_idx, logical_block_idx] -> physical_page_idx
        self.page_table = -torch.ones((max_batch_size, self.n_pages), dtype=torch.int64, device=device)
        self.page_table[0, :] = 0  # page 0 is reserved for simpler code in assign_prefill_no_paging
        self.page_table_cpu = [[] for _ in range(max_batch_size)]

        self.capacity = [0 for _ in range(max_batch_size)]  # capacity: batch_idx -> number of pages allocated * page size
        self.free_pages = list(reversed(range(1, n_pages)))  # page 0 is reserved for simpler code in assign_prefill_no_paging
        self.free_batch_idx = list(reversed(range(1, max_batch_size)))  # batch_idx 0 is reserved for no-op

        # [logical_batch_idx, physical_page_idx] -> logical_page_idx
        self.physical_to_logical = -torch.ones((max_batch_size, n_pages), dtype=torch.int64, device=device)

    def can_reserve(self, size: int, batch_idx_int: int | None = None) -> bool:
        """check if we can reserve new pages for an existing request or a new request, without gpu operations"""
        if batch_idx_int is None:
            # check if we can schedule a new request
            return self.pages_available * self.page_size >= size and len(self.free_batch_idx) > 0
        else:
            # check if we can reserve new pages for an existing request
            return self.reserve(batch_idx_int, None, size, dry_run=True)

    def allocate(self) -> int:
        """allocate a new batch"""
        batch_idx = self.free_batch_idx.pop()

        self.capacity[batch_idx] = 0
        self.physical_to_logical[batch_idx, :] = -1
        self.page_table[batch_idx, :] = -1
        return batch_idx

    @property
    def pages_available(self) -> int:
        return len(self.free_pages)

    def reserve(self, batch_idx_int: int, batch_idx: torch.Tensor, seq_len: int, dry_run: bool = False) -> bool:
        """
        Requests the capacity of a given batch to be at least enough to
        hold `seq_len` elements.

        Args:
            batch_idx_int (int): batch index to be reserved;
            batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`.
            seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`.

        Returns:
            bool: True if the reservation was successful, False if the reservation was not successful (no space, and in this case, no update is done)
        """

        if seq_len <= self.capacity[batch_idx_int]:
            return True

        num_pages_to_allocate = _cdiv(seq_len - self.capacity[batch_idx_int], self.page_size)

        can_allocate = num_pages_to_allocate <= self.pages_available
        if dry_run:
            return can_allocate

        if not can_allocate:
            raise RuntimeError(
                f"Cannot reserve {num_pages_to_allocate} pages for a sequence of length {seq_len} "
                f"in batch {batch_idx_int}. Only {self.pages_available} pages available. "
                f"Current capacity is {self.capacity[batch_idx_int]} tokens."
            )

        start_page_idx = self.capacity[batch_idx_int] // self.page_size
        end_page_idx = start_page_idx + num_pages_to_allocate

        # find empty physical pages
        allocated_pages_list = self.free_pages[-num_pages_to_allocate:]
        allocated_pages = torch.tensor(allocated_pages_list, device=self.device)
        # update page table
        self.page_table[batch_idx, start_page_idx:end_page_idx] = allocated_pages

        # update metadata
        self.physical_to_logical[batch_idx, allocated_pages] = torch.arange(
            start_page_idx,
            end_page_idx,
            device=self.device,
        )
        # update cpu side metadata
        self.page_table_cpu[batch_idx_int] += allocated_pages_list
        self.free_pages = self.free_pages[:-num_pages_to_allocate]
        self.capacity[batch_idx_int] += num_pages_to_allocate * self.page_size
        return True

    def erase(self, batch_idx: int) -> None:
        """
        Removes a single batch from paged attention.

        Args:
            batch_idx (int): batch index to be removed;
        """
        # NOTE: the GPU side data will only be reset/overwritten when we allocate it for a new batch
        self.free_batch_idx.append(batch_idx)
        allocated_pages_cpu = self.page_table_cpu[batch_idx]
        self.free_pages.extend(reversed(allocated_pages_cpu))
        self.page_table_cpu[batch_idx] = []

    def assign(
        self,
        batch_idx: torch.Tensor,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> None:
        """
        Assigns new contents `val` to the storage `cache` at the location
        `batch_idx` and `input_pos`.

        Args:
            batch_idx (Tensor): batch index; shape :math:`(B)`.
            input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`.
            val (Tensor): value to be assigned; shape :math:`(B, H, S, D)`
            cache (Tensor): the cache to store the values; shape:`(1, H, MAX_S, D)`
        """
        if k_val.requires_grad:
            raise RuntimeError("val must not require gradient")

        B, H, S, K_D = k_val.shape
        _, H_cache, MAX_S, D_cache = k_cache.shape
        assert H_cache == H, "number of heads must match"
        assert MAX_S >= S, "cache must have enough space"
        assert D_cache == K_D, "hidden dim must match"
        assert input_pos.shape == (B, S), "input_pos must have the same shape as val"
        assert batch_idx.shape == (B,), "batch_idx must have one dimension only"

        V_D = v_val.shape[3]
        if B != batch_idx.shape[0]:
            raise RuntimeError(f"Expect val and batch_idx have the same batch size but got B={B} and B={batch_idx.shape[0]}.")
        if H != k_cache.shape[1]:
            raise RuntimeError(f"Expect val and cache has the same number of heads but got H={H} and H={k_cache.shape[1]}.")
        if S != input_pos.shape[1]:
            raise RuntimeError(f"Expect val and input_pos has the same length but got S={S} and S={input_pos.shape[0]}.")
        if K_D != k_cache.shape[3]:
            raise RuntimeError(f"Expect k_val and k_cache has the same hidden dim but got D={K_D} and D={k_cache.shape[3]}.")
        if V_D != v_cache.shape[3]:
            raise RuntimeError(f"Expect v_val and v_cache has the same hidden dim but got D={V_D} and D={v_cache.shape[3]}.")

        # find address
        logical_block_idx = input_pos // self.page_size  # [B, S]
        logical_block_offset = input_pos % self.page_size  # [B, S]

        # NOTE: this code path is only used for decoding. For batch prefill, use assign_prefill_no_paging() instead
        physical_block_idx = torch.gather(self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)).to(torch.int32)  # [B, S]

        addr = (physical_block_idx * self.page_size + logical_block_offset).view(-1)  # [B*S]

        k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)
        v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)

        k_cache[:, :, addr, :] = k_val
        v_cache[:, :, addr, :] = v_val

        return k_cache, v_cache

    def convert_logical_block_mask(
        self,
        block_mask: BlockMask,
        batch_idx: Optional[torch.Tensor] = None,
    ) -> BlockMask:
        """
        Converts a logical block mask by mapping its logical kv indices to the corresponding
        physical kv indices.

        Args:
            block_mask (BlockMask): logical block mask;
                kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`.
            batch_idx (Tensor): batch index corresponding to the block_mask
                batch dimension. This provides flexibility to convert a
                block mask with smaller batch size than the page table;
                shape :math:`(B)`.
        """
        B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape

        if block_mask.BLOCK_SIZE[1] != self.page_size:
            raise RuntimeError(
                f"Expect block_mask has the same column block size as page_sizebut got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}"
            )

        device = block_mask.kv_num_blocks.device

        if batch_idx is None:
            batch_idx = torch.arange(B, device=device)

        assert batch_idx.ndim == 1, "batch_idx must be a 1D tensor"
        assert batch_idx.shape[0] == B, "batch_idx must have the same shape as block_mask"
        assert B <= self.max_batch_size, "batch_idx must be less than or equal to max_batch_size"

        page_table = self.page_table[batch_idx]

        def transform(num_blocks, indices):
            """
            transform the block mask from [B, H, num_q_blocks, num_logical_kv_blocks]
            to [B, H, num_q_blocks, num_physical_kv_blocks]

            kv_num_blocks: [B, H, num_q_blocks] -> unchanged
            kv_indices: [B, H, num_q_blocks, num_logical_kv_blocks] -> [B, H, num_q_blocks, num_physical_kv_blocks]
            """
            if num_blocks is None:
                return None, None
            new_kv_num_blocks = num_blocks.clone()
            new_kv_indices = torch.zeros((B, H, ROWS, self.n_pages), dtype=torch.int32, device=device)
            new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
                torch.gather(page_table, 1, indices.view(B, -1).to(torch.int64)).view(block_mask.kv_indices.shape).to(torch.int32)
            )
            return new_kv_num_blocks, new_kv_indices

        new_kv_num_blocks, new_kv_indices = transform(block_mask.kv_num_blocks, block_mask.kv_indices)
        new_full_kv_num_blocks, new_full_kv_indices = transform(block_mask.full_kv_num_blocks, block_mask.full_kv_indices)

        new_mask_mod = self.get_mask_mod(block_mask.mask_mod, batch_idx)

        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
        return BlockMask.from_kv_blocks(
            new_kv_num_blocks,
            new_kv_indices,
            new_full_kv_num_blocks,
            new_full_kv_indices,
            block_mask.BLOCK_SIZE,
            new_mask_mod,
            seq_lengths=seq_lengths,
        )

    def get_logical_kv_idx(self, physical_batch_idx: torch.Tensor, physical_kv_idx: torch.Tensor, batch_idx: torch.Tensor):
        logical_batch_idx = batch_idx[physical_batch_idx]
        physical_kv_block = physical_kv_idx // self.page_size
        physical_kv_offset = physical_kv_idx % self.page_size
        logical_block_idx = self.physical_to_logical[logical_batch_idx, physical_kv_block]
        logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
        is_valid = logical_block_idx >= 0
        safe_logical_kv_idx = logical_kv_idx.clamp(min=0)
        return is_valid, safe_logical_kv_idx

    def get_mask_mod(self, mask_mod: Optional[_mask_mod_signature], batch_idx: torch.Tensor) -> _mask_mod_signature:
        """
        Converts a mask_mod based on mapping from the physical block index to the logical
        block index.

        Args:
            mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
        """
        if mask_mod is None:
            mask_mod = noop_mask

        def new_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            is_valid, safe_logical_kv_idx = self.get_logical_kv_idx(b, physical_kv_idx, batch_idx)
            return torch.where(is_valid, mask_mod(b, h, q_idx, safe_logical_kv_idx), False)

        return new_mask_mod

    # NOTE: not used in the current codebase
    def get_score_mod(self, score_mod: Optional[_score_mod_signature], batch_idx: torch.Tensor) -> _score_mod_signature:
        """
        Converts a score_mod based on mapping from the physical block index to the logical
        block index.

        Args:
            score_mod (_score_mod_signature): score_mod based on the logical block index.
        """
        if score_mod is None:
            score_mod = _identity

        def new_score_mod(
            score: torch.Tensor,
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            is_valid, safe_logical_kv_idx = self.get_logical_kv_idx(b, physical_kv_idx, batch_idx)
            return torch.where(
                is_valid,
                score_mod(score, b, h, q_idx, safe_logical_kv_idx),
                float("-inf"),
            )

        return new_score_mod

    def create_causal_blockmask(self, B, L):
        """A minimal, unoptimized causal block mask creation function"""

        def causal(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        return create_block_mask(causal, B=B, H=None, Q_LEN=L, KV_LEN=L, BLOCK_SIZE=self.page_size, device=self.device)

    def create_prefill_blockmask_no_paging(self, batch_idx: Tensor, BLOCK_SIZE: int = 128):
        """
        there's no prefix sharing implemented, batch_idx is the document id, batch_idx is not guaranteed to be sorted
        """
        assert batch_idx.ndim == 2, "batch_idx must be a 2D tensor"
        assert batch_idx.shape[0] == 1, "batch_idx must have batch size 1"
        L = batch_idx.shape[1]
        docs = batch_idx.view(-1)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        return create_block_mask(document_causal, B=1, H=None, Q_LEN=L, KV_LEN=L, BLOCK_SIZE=BLOCK_SIZE)

    # we assign prefill to the cache, similar to assign(), except we don't return the k_cache, v_cache, we only return the k_val, v_val
    def assign_prefill_no_paging(
        self,
        batch_idx: torch.Tensor,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> None:
        """
        assigns kv and returns the original kv

        batch_idx: [1, L]
        input_pos: [1, L]
        k_val: [1, H, L, D]
        v_val: [1, H, L, D]
        k_cache: [1, H, MAX_S, D]
        v_cache: [1, H, MAX_S, D]
        """

        assert batch_idx.ndim == 2, "batch_idx must be a 2D tensor"
        assert input_pos.ndim == 2, "input_pos must be a 2D tensor"
        assert k_val.ndim == 4, "k_val must be a 4D tensor"
        assert v_val.ndim == 4, "v_val must be a 4D tensor"
        assert k_cache.ndim == 4, "k_cache must be a 4D tensor"
        assert v_cache.ndim == 4, "v_cache must be a 4D tensor"
        assert batch_idx.shape[0] == 1, "batch_idx must have batch size 1"

        input_pos_block_idx = input_pos // self.page_size
        input_pos_offset_in_block = input_pos % self.page_size
        physical_kv_idx = self.page_table[batch_idx, input_pos_block_idx] * self.page_size + input_pos_offset_in_block
        k_cache[:, :, physical_kv_idx.view(-1), :] = k_val
        v_cache[:, :, physical_kv_idx.view(-1), :] = v_val

        return k_val, v_val


================================================
FILE: plot_metrics.py
================================================
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "pandas",
#     "matplotlib",
# ]
# ///

import pandas as pd
import matplotlib.pyplot as plt

# Read the CSV files
flex_nano_df = pd.read_csv('flex_nano_vllm_metrics.csv')
vllm_df = pd.read_csv('vllm_metrics.csv')

# Create figure with subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Running requests comparison
ax1.plot(flex_nano_df['step'], flex_nano_df['requests_running'], 
         label='Flex Nano VLLM', color='blue', linewidth=1.5)
ax1.plot(vllm_df['steps'], vllm_df['requests_running'], 
         label='VLLM', color='red', linewidth=1.5)
ax1.set_title('Running Requests Over Time')
ax1.set_xlabel('Step')
ax1.set_ylabel('Running Requests')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Waiting requests comparison
ax2.plot(flex_nano_df['step'], flex_nano_df['requests_waiting'], 
         label='Flex Nano VLLM', color='blue', linewidth=1.5)
ax2.plot(vllm_df['steps'], vllm_df['requests_waiting'], 
         label='VLLM', color='red', linewidth=1.5)
ax2.set_title('Waiting Requests Over Time')
ax2.set_xlabel('Step')
ax2.set_ylabel('Waiting Requests')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Flex Nano VLLM step types
prefill_steps = flex_nano_df[flex_nano_df['step_type'] == 'prefill']
decode_steps = flex_nano_df[flex_nano_df['step_type'] == 'decode']

ax3.scatter(prefill_steps['step'], prefill_steps['requests_running'], 
           label='Prefill', alpha=0.6, s=10, color='green')
ax3.scatter(decode_steps['step'], decode_steps['requests_running'], 
           label='Decode', alpha=0.6, s=10, color='orange')
ax3.set_title('Flex Nano VLLM: Running Requests by Step Type')
ax3.set_xlabel('Step')
ax3.set_ylabel('Running Requests')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: Total requests (running + waiting)
flex_nano_total = flex_nano_df['requests_running'] + flex_nano_df['requests_waiting']
vllm_total = vllm_df['requests_running'] + vllm_df['requests_waiting']

ax4.plot(flex_nano_df['step'], flex_nano_total, 
         label='Flex Nano VLLM Total', color='blue', linewidth=1.5)
ax4.plot(vllm_df['steps'], vllm_total, 
         label='VLLM Total', color='red', linewidth=1.5)
ax4.set_title('Total Requests (Running + Waiting)')
ax4.set_xlabel('Step')
ax4.set_ylabel('Total Requests')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('metrics_comparison.png', dpi=300, bbox_inches='tight')
print("Metrics comparison saved as 'metrics_comparison.png'")
plt.show()

# Print some summary statistics
print("\n=== Summary Statistics ===")
print(f"Flex Nano VLLM - Max running: {flex_nano_df['requests_running'].max()}")
print(f"VLLM - Max running: {vllm_df['requests_running'].max()}")
print(f"Flex Nano VLLM - Max waiting: {flex_nano_df['requests_waiting'].max()}")
print(f"VLLM - Max waiting: {vllm_df['requests_waiting'].max()}")
print(f"Flex Nano VLLM - Total steps: {len(flex_nano_df)}")
print(f"VLLM - Total steps: {len(vllm_df)}")

================================================
FILE: pyproject.toml
================================================
[project]
name = "flex-nano-vllm"
version = "0.1.0"
description = "Flex-attention based nano-vllm implementation for fast PaliGemma inference"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "accelerate>=1.9.0",
    "datasets>=3.0.0",
    "hf-transfer>=0.1.9",
    "matplotlib>=3.10.3",
    "torch>=2.7.1",
    "tqdm>=4.67.1",
    "transformers>=4.53.2",
    "triton>=3.3.1",
]

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
packages = ["flex_nano_vllm"]

[dependency-groups]
dev = [
    "rich>=14.1.0",
]

[tool.uv.sources]
transformers = { git = "https://github.com/huggingface/transformers", rev = "34133d0a" }


================================================
FILE: visualize.py
================================================
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "matplotlib",
#     "numpy",
# ]
# ///

import matplotlib.pyplot as plt
import numpy as np

# Data
configs = ['50% GPU', '90% GPU', '90% GPU\n(high batch)']
vllm_output = [3020, 3772, 3840]
flex_output = [2313, 3076, 3440]

# Create figure
fig, ax = plt.subplots(figsize=(12, 8))

x = np.arange(len(configs))
width = 0.35

bars1 = ax.bar(x - width/2, vllm_output, width, label='vLLM v1', color='#1f77b4', alpha=0.8)
bars2 = ax.bar(x + width/2, flex_output, width, label='flex-nano-vllm', color='#ff7f0e', alpha=0.8)

ax.set_title('Output Tokens/s Comparison by Configuration', fontsize=16, fontweight='bold', pad=20)
ax.set_ylabel('Tokens/s', fontsize=14)
ax.set_xlabel('GPU Memory Configuration', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(configs, fontsize=12)
ax.legend(fontsize=12)
ax.grid(axis='y', alpha=0.3)

# Add value labels
for bar in bars1:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 50,
            f'{int(height)}', ha='center', va='bottom', fontweight='bold', fontsize=11)

# Add value labels with percentages for flex-nano-vllm
for i, bar in enumerate(bars2):
    height = bar.get_height()
    percentage = (flex_output[i] / vllm_output[i]) * 100
    ax.text(bar.get_x() + bar.get_width()/2., height + 50,
            f'{int(height)}\n({percentage:.1f}%)', ha='center', va='bottom', fontweight='bold', fontsize=11)

plt.tight_layout()

# Save the plot
plt.savefig('tokens_per_second_comparison.png', dpi=300, bbox_inches='tight')
print("Simple comparison saved as 'tokens_per_second_comparison.png'")

plt.show()
Download .txt
gitextract_2vld2n29/

├── .gitignore
├── LICENSE
├── README.md
├── THIRD_PARTY_LICENSES.md
├── benchmark.py
├── benchmark_vllm.py
├── flex_nano_vllm/
│   ├── __init__.py
│   ├── inference.py
│   ├── modeling_gemma2.py
│   └── paged_attention.py
├── plot_metrics.py
├── pyproject.toml
└── visualize.py
Download .txt
SYMBOL INDEX (89 symbols across 5 files)

FILE: benchmark.py
  function get_profiler_context (line 50) | def get_profiler_context():
  function debug_print_outputs (line 67) | def debug_print_outputs(sequences, tokenizer, slice=slice(None), referen...

FILE: benchmark_vllm.py
  class BenchmarkRequest (line 26) | class BenchmarkRequest:
  function generate_benchmark_data (line 39) | def generate_benchmark_data(tokenizer, n_requests=512, max_input_length=...
  function print_step_stats (line 75) | def print_step_stats(steps, name):
  function generate_with_timing (line 87) | def generate_with_timing(llm, sequences, sampling_params, collect_detail...

FILE: flex_nano_vllm/inference.py
  class SamplingParams (line 19) | class SamplingParams:
  function sample (line 23) | def sample(logits_BV, greedy=True, to_cpu=False):
  class Sequence (line 43) | class Sequence:
    method __init__ (line 44) | def __init__(self, text: str):
    method add_next_token (line 56) | def add_next_token(self, token_id: torch.Tensor, logits: torch.Tensor,...
    method copy (line 63) | def copy(self):
    method output_ids (line 67) | def output_ids(self):
    method output_logits (line 71) | def output_logits(self):
    method output_probs (line 75) | def output_probs(self):
    method output_length (line 79) | def output_length(self):
    method total_length (line 83) | def total_length(self):
    method total_token_ids (line 87) | def total_token_ids(self):
    method last_token_id (line 93) | def last_token_id(self):
  function process_sampling_params (line 97) | def process_sampling_params(sequences: list[Sequence], sampling_params: ...
  class Inference (line 109) | class Inference:
    method __init__ (line 110) | def __init__(self, model, tokenizer, max_batch_size, max_seq_length, n...
    method _prefill_sequences (line 136) | def _prefill_sequences(
    method prefill_sequences (line 184) | def prefill_sequences(self, sequences: list[Sequence]) -> torch.Tensor:
    method get_decoding_block_mask (line 206) | def get_decoding_block_mask(self, batch_idx: torch.Tensor):
    method _decode_step (line 252) | def _decode_step(self, batch_idx: torch.Tensor, input_ids: torch.Tensor):
    method decode_step (line 266) | def decode_step(self, batch_idx: torch.Tensor, input_ids: torch.Tensor...
    method _check_done (line 287) | def _check_done(self, sequences: list[Sequence]):
    method run_one_step (line 298) | def run_one_step(self):
    method tokenize (line 362) | def tokenize(self, sequences: list[Sequence]):
    method generate (line 369) | def generate(
    method capture_decode_cudagraph (line 430) | def capture_decode_cudagraph(self):
    method print_time_stats (line 464) | def print_time_stats(self, times):

FILE: flex_nano_vllm/modeling_gemma2.py
  class Gemma2RMSNorm (line 54) | class Gemma2RMSNorm(nn.Module):
    method __init__ (line 55) | def __init__(self, dim: int, eps: float = 1e-6):
    method _norm (line 60) | def _norm(self, x):
    method forward (line 63) | def forward(self, x):
    method extra_repr (line 70) | def extra_repr(self):
  class Gemma2MLP (line 73) | class Gemma2MLP(nn.Module):
    method __init__ (line 74) | def __init__(self, config):
    method forward (line 84) | def forward(self, x):
  function rotate_half (line 89) | def rotate_half(x):
  function apply_rotary_pos_emb (line 96) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_di...
  function repeat_kv (line 123) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  function eager_attention_forward (line 135) | def eager_attention_forward(
  class Gemma2Attention (line 170) | class Gemma2Attention(nn.Module):
    method __init__ (line 173) | def __init__(self, config: Gemma2Config, layer_idx: int):
    method forward (line 199) | def forward(
  class Gemma2DecoderLayer (line 252) | class Gemma2DecoderLayer(GradientCheckpointingLayer):
    method __init__ (line 253) | def __init__(self, config: Gemma2Config, layer_idx: int):
    method forward (line 267) | def forward(
  class Gemma2RotaryEmbedding (line 312) | class Gemma2RotaryEmbedding(nn.Module):
    method __init__ (line 313) | def __init__(self, config: Gemma2Config, device=None):
    method forward (line 332) | def forward(self, x, position_ids):
  class Gemma2PreTrainedModel (line 347) | class Gemma2PreTrainedModel(PreTrainedModel):
    method _init_weights (line 364) | def _init_weights(self, module):
  class Gemma2Model (line 379) | class Gemma2Model(Gemma2PreTrainedModel):
    method __init__ (line 380) | def __init__(self, config: Gemma2Config):
    method get_input_embeddings (line 396) | def get_input_embeddings(self):
    method set_input_embeddings (line 399) | def set_input_embeddings(self, value):
    method forward (line 404) | def forward(
  class Gemma2ForCausalLM (line 498) | class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    method __init__ (line 503) | def __init__(self, config):
    method get_input_embeddings (line 512) | def get_input_embeddings(self):
    method set_input_embeddings (line 515) | def set_input_embeddings(self, value):
    method get_output_embeddings (line 518) | def get_output_embeddings(self):
    method set_output_embeddings (line 521) | def set_output_embeddings(self, new_embeddings):
    method set_decoder (line 524) | def set_decoder(self, decoder):
    method get_decoder (line 527) | def get_decoder(self):
    method forward (line 532) | def forward(

FILE: flex_nano_vllm/paged_attention.py
  function _cdiv (line 23) | def _cdiv(x: int | float | torch.Tensor, multiple: int | float | torch.T...
  class PagedKVCache (line 27) | class PagedKVCache(torch.nn.Module):
    method __init__ (line 28) | def __init__(self, page_table, n_heads, head_dim, dtype):
    method update (line 36) | def update(self, input_pos, k_val, v_val, batch_idx=None):
  class PageTable (line 48) | class PageTable:
    method __init__ (line 59) | def __init__(
    method can_reserve (line 83) | def can_reserve(self, size: int, batch_idx_int: int | None = None) -> ...
    method allocate (line 92) | def allocate(self) -> int:
    method pages_available (line 102) | def pages_available(self) -> int:
    method reserve (line 105) | def reserve(self, batch_idx_int: int, batch_idx: torch.Tensor, seq_len...
    method erase (line 156) | def erase(self, batch_idx: int) -> None:
    method assign (line 169) | def assign(
    method convert_logical_block_mask (line 228) | def convert_logical_block_mask(
    method get_logical_kv_idx (line 296) | def get_logical_kv_idx(self, physical_batch_idx: torch.Tensor, physica...
    method get_mask_mod (line 306) | def get_mask_mod(self, mask_mod: Optional[_mask_mod_signature], batch_...
    method get_score_mod (line 329) | def get_score_mod(self, score_mod: Optional[_score_mod_signature], bat...
    method create_causal_blockmask (line 356) | def create_causal_blockmask(self, B, L):
    method create_prefill_blockmask_no_paging (line 364) | def create_prefill_blockmask_no_paging(self, batch_idx: Tensor, BLOCK_...
    method assign_prefill_no_paging (line 381) | def assign_prefill_no_paging(
Condensed preview — 13 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (100K chars).
[
  {
    "path": ".gitignore",
    "chars": 42,
    "preview": "__pycache__/\n*.egg-info/\ntrace_dir/\n*.csv\n"
  },
  {
    "path": "LICENSE",
    "chars": 1070,
    "preview": "MIT License\n\nCopyright (c) 2025 Jonathan Chang\n\nPermission is hereby granted, free of charge, to any person obtaining a "
  },
  {
    "path": "README.md",
    "chars": 3121,
    "preview": "# flex-nano-vllm\n\nFlexAttention based, minimal vllm-style inference engine for fast Gemma 2 inference.\n\n## Introduction\n"
  },
  {
    "path": "THIRD_PARTY_LICENSES.md",
    "chars": 2785,
    "preview": "# Third Party Licenses\n\nThis project incorporates code from third-party open source projects. The following licenses app"
  },
  {
    "path": "benchmark.py",
    "chars": 11131,
    "preview": "\"\"\"\nThis script is used to test the correctness and benchmark the paged attention implementation.\n\n\nsimplified interface"
  },
  {
    "path": "benchmark_vllm.py",
    "chars": 7176,
    "preview": "# /// script\n# requires-python = \">=3.12\"\n# dependencies = [\n#     \"vllm\",\n#     \"transformers\",\n#     \"datasets\",\n#    "
  },
  {
    "path": "flex_nano_vllm/__init__.py",
    "chars": 270,
    "preview": "\"\"\"flex-nano-vllm - Flex-attention based nano-vllm implementation for fast PaliGemma inference.\"\"\"\n\nfrom .modeling_gemma"
  },
  {
    "path": "flex_nano_vllm/inference.py",
    "chars": 21753,
    "preview": "from collections import deque\nimport time\nimport torch\nfrom tqdm import tqdm\nfrom torch.nn.attention.flex_attention impo"
  },
  {
    "path": "flex_nano_vllm/modeling_gemma2.py",
    "chars": 25835,
    "preview": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from"
  },
  {
    "path": "flex_nano_vllm/paged_attention.py",
    "chars": 17789,
    "preview": "# Adapted from attention-gym\n# Original source: https://github.com/pytorch-labs/attention-gym\n# License: BSD 3-Clause (s"
  },
  {
    "path": "plot_metrics.py",
    "chars": 3002,
    "preview": "# /// script\n# requires-python = \">=3.12\"\n# dependencies = [\n#     \"pandas\",\n#     \"matplotlib\",\n# ]\n# ///\n\nimport panda"
  },
  {
    "path": "pyproject.toml",
    "chars": 704,
    "preview": "[project]\nname = \"flex-nano-vllm\"\nversion = \"0.1.0\"\ndescription = \"Flex-attention based nano-vllm implementation for fas"
  },
  {
    "path": "visualize.py",
    "chars": 1643,
    "preview": "# /// script\n# requires-python = \">=3.12\"\n# dependencies = [\n#     \"matplotlib\",\n#     \"numpy\",\n# ]\n# ///\n\nimport matplo"
  }
]

About this extraction

This page contains the full source code of the changjonathanc/flex-nano-vllm GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 13 files (94.1 KB), approximately 23.6k tokens, and a symbol index with 89 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!