[
  {
    "path": ".gitignore",
    "content": "__pycache__/\n*.egg-info/\ntrace_dir/\n*.csv\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2025 Jonathan Chang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# flex-nano-vllm\n\nA minimal, FlexAttention-based, vLLM-style inference engine for fast Gemma 2 inference.\n\n## Introduction\n\nThis project has no flash-attn dependency and no custom Triton kernels. Everything is implemented with FlexAttention. The code is commented and the structure is flat. Read the accompanying write-up: [vLLM flex attention from scratch](https://jonathanc.net/blog/vllm-flex-attention-from-scratch).\n\n## Code Structure\n\n```\nflex-nano-vllm/\n├── benchmark.py                   # Testing and benchmarking script.\n├── benchmark_vllm.py              # vLLM comparison benchmark (uses uv inline dependencies to run vLLM).\n├── visualize.py                   # Performance visualization script.\n└── flex_nano_vllm/\n    ├── inference.py               # Main inference engine, uses paged attention.\n    ├── modeling_gemma2.py         # Gemma 2 model implementation, copied from transformers.\n    └── paged_attention.py         # Paged attention implementation, including page table and paged KV cache. 
Based on attention-gym.\n```\n\n## Quick Start\n\n```\nuv sync\n\n# run test and benchmark\nuv run benchmark.py\n\n# compare with vLLM\nuv run benchmark_vllm.py\n\n# enable profiling to save more metrics to a CSV file\n# ENABLE_PROFILING=1 uv run benchmark_vllm.py\n```\n\n## Results\n\nTest configuration:\n- PyTorch version: 2.7.1+cu128\n- GPU: RTX 3090 x 1 (24GB)\n- Model: google/gemma-2-2b\n- Workload: 512 requests, at most 512 input tokens each, variable output tokens (128-512)\n- Configs tested: vLLM at 50% & 90% GPU memory, flex-nano-vllm with the same page allocation as vLLM\n\n| Implementation | Output Tokens/s | Requests/s | Total Throughput* |\n|---------------|----------------|-----------|------------------|\n| vLLM v1, 90% GPU memory, high batch size† | 3,840 | 17.67 | 7,234 |\n| vLLM v1, 90% GPU memory | 3,772 | 15.26 | 6,401 |\n| flex-nano-vllm, 90% GPU memory, high batch size† | 3,440 | 14.30 | 5,817 |\n| flex-nano-vllm, 90% GPU memory | 3,076 | 13.06 | 5,382 |\n| vLLM v1, 50% GPU memory | 3,020 | 13.74 | 5,448 |\n| flex-nano-vllm, 50% GPU memory | 2,313 | 9.96 | 4,068 |\n\n*Total throughput counts both input and output tokens, in tokens/s  \n† High batch size means max_num_seqs=512 in vLLM (maximum allowed concurrency)\n\n![Performance Comparison](tokens_per_second_comparison.png)\n\n## License\n\nThis project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.\n\nThird-party code incorporated in this project retains its original licenses. 
See [THIRD_PARTY_LICENSES.md](THIRD_PARTY_LICENSES.md) for details.\n\n## Acknowledgments\n\n- [GeeeekExplorer/nano-vllm](https://github.com/GeeeekExplorer/nano-vllm): This project is inspired by nano-vllm.\n- [pytorch-labs/attention-gym](https://github.com/pytorch-labs/attention-gym): The paged attention implementation is based on attention-gym.\n- [huggingface/transformers](https://github.com/huggingface/transformers): I copied the Gemma 2 model from transformers and modified it to use flex attention / paged attention.\n- [vllm-project/vllm](https://github.com/vllm-project/vllm): vLLM supports a flex attention backend, which helped me find a useful flag in flex_attention.\n"
  },
  {
    "path": "THIRD_PARTY_LICENSES.md",
    "content": "# Third Party Licenses\n\nThis project incorporates code from third-party open source projects. The following licenses apply to the respective components:\n\n## Hugging Face Transformers\n\nThis project includes modified code from the transformers project:\n- **Source**: https://github.com/huggingface/transformers\n- **Files**: `modeling_gemma2.py` (Gemma2 model implementation)\n- **License**: Apache License 2.0\n- **Copyright**: Copyright 2024 Google Inc. HuggingFace Inc. team\n\n### Apache License 2.0\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n## Attention Gym\n\nThis project includes modified code from the attention-gym project:\n- **Source**: https://github.com/pytorch-labs/attention-gym  \n- **Files**: `paged_attention.py` and related components\n- **License**: BSD 3-Clause License\n- **Copyright**: Copyright (c) 2023, Driss Guessous\n\n### BSD 3-Clause License\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. 
Neither the name of the copyright holder nor the names of its\n   contributors may be used to endorse or promote products derived from\n   this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "benchmark.py",
    "content": "\"\"\"\nThis script tests the correctness of the paged attention implementation and benchmarks it.\n\n\nSimplified interface of the Inference class:\nUsage:\nllm = Inference(...)\nsequences = [Sequence(text) for text in texts]\nllm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True, capture_cudagraph=True)\noutputs = debug_print_outputs(sequences, tokenizer)\n\n\n\n\n## correctness:\n1. run 2 different sequences with different lengths, the output should match huggingface .generate()\n        this output has been verified to be correct outside of this script\n2. run the same 2 sequences again through the paged attention Inference class, the output should match 1.\n        this checks that the Inference class has no side effects that can alter the output across .generate() calls\n\n### correctness with dynamic batching\n3. run with a different number of requests, with the same 2 sequences mixed into the batch\n        the output should match 1.\n\n### correctness with cuda graph\n4. 
after cudagraph capture, run some batch of requests, and the same 2 sequences are mixed in the batch\n    the output should match 1.\n\n### tests we don't cover, but might be useful to have\n\n- PageTable unit tests\n- tests with output length longer than one page (128 tokens)\n\n\"\"\"\n\nfrom transformers import AutoTokenizer\nimport time\nfrom flex_nano_vllm import Gemma2ForCausalLM\nfrom flex_nano_vllm.inference import Inference, Sequence, SamplingParams\nfrom benchmark_vllm import generate_benchmark_data, long_prompt, short_prompt\n\nimport torch\nfrom rich.console import Console\nfrom torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler\n\nconsole = Console()\ntorch.set_float32_matmul_precision(\"high\")\n\n\ndef get_profiler_context():\n    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]\n    profiler_context = profile(\n        activities=activities,\n        schedule=schedule(wait=0, warmup=10, active=10, repeat=10),\n        on_trace_ready=tensorboard_trace_handler(\"trace_dir\"),\n        record_shapes=False,\n        profile_memory=False,\n        with_stack=True,\n        with_flops=False,\n    )\n    return profiler_context\n\n\n# long_prompt and short_prompt are now imported from bench_utils\n\n\ndef debug_print_outputs(sequences, tokenizer, slice=slice(None), reference=None, prefix_match=False):\n    results = []\n    for i, seq in enumerate(sequences[slice]):\n        output_ids = seq.output_ids\n        output_decoded = tokenizer.decode(output_ids, skip_special_tokens=True)\n        input_decoded = tokenizer.decode(seq.input_ids, skip_special_tokens=False)\n\n        # If reference provided, only print on mismatch\n        if reference is not None and i < len(reference):\n            # Check for exact match or prefix match\n            if prefix_match:\n                matches = output_decoded.startswith(reference[i])\n            else:\n                matches = output_decoded == reference[i]\n\n         
   if matches:\n                match_type = \"prefix match\" if prefix_match else \"match\"\n                console.print(f\"i={i} ✓ {match_type}\", style=\"green\", markup=False)\n            else:\n                mismatch_type = \"PREFIX MISMATCH\" if prefix_match else \"MISMATCH\"\n                console.print(f\"i={i} ✗ {mismatch_type}\", style=\"red bold\", markup=False)\n                console.print(f\"  expected: {reference[i][:32]}...\", style=\"red\", markup=False)\n                console.print(f\"  got:      {output_decoded[:32]}...\", style=\"red\", markup=False)\n                console.print(f\"  input:    {input_decoded}\", style=\"dim\", markup=False)\n        else:\n            # Normal detailed output when no reference\n            console.print(f\"{i=} {input_decoded=} {output_decoded[:32]=}\", style=\"bold \", markup=False)\n\n        results.append(output_decoded)\n    return results\n\n\nif __name__ == \"__main__\":\n    # Load model and tokenizer\n    model_id = \"google/gemma-2-2b\"\n    tokenizer = AutoTokenizer.from_pretrained(model_id)\n\n    model = Gemma2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=\"auto\").eval()\n    # model = torch.compile(model)\n\n    B = 8\n    max_new_tokens = 8\n    # settings to match vLLM (see benchmark_vllm.py)\n    paged_attn_max_batch_size = 256  # match vLLM max_num_seqs=256\n    max_seq_length = 2048\n    max_input_length = 1024\n    token_allocation = 45_360  # 50% memory usage, this number is derived by running uv run benchmark_vllm.py and checking the logs (on a 3090)\n    token_allocation = 140_432  # 90% memory usage, this number is derived by running uv run benchmark_vllm.py and checking the logs (on a 3090)\n\n    prefill_length_limit = 1024 * 8  # helps control peak memory usage for prefill\n\n    page_size = 128\n    n_pages = int(token_allocation) // page_size\n\n    print(\"initializing vllm inference\")\n    llm = Inference(\n        model,\n        
tokenizer,\n        max_batch_size=paged_attn_max_batch_size,\n        max_seq_length=max_seq_length,\n        n_pages=n_pages,\n        kernel_options={\"BLOCK_M\": 32, \"BLOCK_N\": 32},\n        prefill_length_limit=prefill_length_limit,\n    )\n\n    print(\"test 1\")\n    ## test 1\n    sequences = [Sequence(long_prompt), Sequence(short_prompt)]\n    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True)\n    results = debug_print_outputs(sequences, tokenizer)\n    del sequences\n    torch.cuda.empty_cache()\n\n    print(\"test 2, same batch\")\n    sequences = [Sequence(long_prompt), Sequence(short_prompt)]\n    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True)\n    results2 = debug_print_outputs(sequences, tokenizer, reference=results)\n    for i in range(len(results)):\n        assert results[i] == results2[i], f\"{i=}, {results[i]=}, {results2[i]=}\"\n    del sequences\n    torch.cuda.empty_cache()\n\n    print(\"test 2.1: reverse order\")\n    sequences = [Sequence(short_prompt), Sequence(long_prompt)]\n    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True)\n    # Debug in the right order: [long_prompt, short_prompt] to match reference\n    reordered_sequences = [sequences[1], sequences[0]]  # [long_prompt, short_prompt]\n    results21 = debug_print_outputs(reordered_sequences, tokenizer, reference=results)\n    for i in range(len(results)):\n        assert results[i] == results21[i], f\"{i=}, {results[i]=}, {results21[i]=}\"\n    del sequences\n    torch.cuda.empty_cache()\n\n    print(\"test 3: batch with other sequence\")\n    sequences = [Sequence(short_prompt), Sequence(long_prompt), Sequence(\"hi\")]\n    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True)\n    # Debug just the sequences we care about in the right order: [long_prompt, short_prompt]\n    
comparison_sequences = [sequences[1], sequences[0]]  # [long_prompt, short_prompt]\n    results3 = debug_print_outputs(comparison_sequences, tokenizer, reference=results)\n    del sequences\n    for i in range(len(results)):\n        assert results[i] == results3[i], f\"{i=}, {results[i]=}, {results3[i]=}\"\n    torch.cuda.empty_cache()\n\n    print(\"test 4: batch with other sequence, capture cudagraph\")\n    sequences = [\n        Sequence(\"this is a test messaage hello \"),\n        Sequence(short_prompt),\n        Sequence(\"test\"),\n        Sequence(long_prompt),\n        Sequence(\"hello world \"),\n    ]\n    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True, capture_cudagraph=True)\n    # Debug just the sequences we care about in the right order: [long_prompt, short_prompt]\n    comparison_sequences = [sequences[3], sequences[1]]  # [long_prompt, short_prompt]\n    results4 = debug_print_outputs(comparison_sequences, tokenizer, reference=results)\n    del sequences\n    for i in range(len(results)):\n        assert results[i] == results4[i], f\"{i=}, {results[i]=}, {results4[i]=}\"\n    torch.cuda.empty_cache()\n\n    # Generate test batch for cudagraph\n    test_requests = generate_benchmark_data(tokenizer, n_requests=4, max_input_length=max_input_length)\n    sequences = [Sequence(req.text) for req in test_requests]\n    llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=16), use_tqdm=True, capture_cudagraph=True)\n    # Just capture cudagraph, no need to show verbose output\n    # debug_print_outputs(sequences, tokenizer)\n\n    print(\"after cudagraph\")\n    test_requests2 = generate_benchmark_data(tokenizer, n_requests=4, max_input_length=max_input_length)\n    sequences = [Sequence(req.text) for req in test_requests2] + [Sequence(long_prompt), Sequence(short_prompt)]\n    print(\"replay cudagraph, & profile\")\n    with get_profiler_context() as prof:\n        
llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=16), use_tqdm=True, profiler=prof)\n    results_prefill = debug_print_outputs(sequences, tokenizer, slice=slice(-2, None), reference=results, prefix_match=True)\n    for i in range(len(results_prefill)):\n        assert results_prefill[i][: len(results[i])] == results[i], f\"{i=}, {results[i]=}, {results_prefill[i][:len(results[i])]=}\"\n    del sequences\n    torch.cuda.empty_cache()\n\n    ## benchmark throughput\n    n_requests = 512\n    max_input_length = 512\n    # Use shared benchmark data generation\n    benchmark_requests = generate_benchmark_data(\n        tokenizer,\n        n_requests=n_requests,\n        max_input_length=max_input_length,\n    )\n\n    # Convert to flex-nano-vllm format\n    sequences = [Sequence(req.text) for req in benchmark_requests]\n    sampling_params = [SamplingParams(max_new_tokens=req.max_new_tokens) for req in benchmark_requests]\n\n    print(\"\\n--- RUNNING BENCHMARK ---\")\n\n    # Reset memory stats to track benchmark-specific usage\n    torch.cuda.reset_peak_memory_stats()\n\n    start_time = time.time()\n    llm.generate(sequences, sampling_params=sampling_params, use_tqdm=False, save_metrics_csv=True, print_stats=True)\n    total_time = time.time() - start_time\n\n    total_output_length = sum(seq.output_length for seq in sequences)\n    total_input_length = sum(len(seq.input_ids) for seq in sequences)\n\n    # Get memory usage\n    peak_memory_mb = torch.cuda.max_memory_allocated() / 1024 / 1024\n    current_memory_mb = torch.cuda.memory_allocated() / 1024 / 1024\n\n    print(\"\\n--- PERFORMANCE METRICS ---\")\n    print(f\"Total time: {total_time:.2f}s\")\n    print(f\"Throughput: {total_output_length / total_time:.1f} tokens/s\")\n    print(f\"Request throughput: {len(sequences) / total_time:.2f} req/s\")\n    print(f\"Total throughput (prompt+new): {(total_input_length + total_output_length) / total_time:.1f} tokens/s\")\n    print(f\"Peak 
memory: {peak_memory_mb:.1f} MB\")\n    print(f\"Current memory: {current_memory_mb:.1f} MB\")\n\n    print(\"\\nafter benchmark\")\n    results_final = debug_print_outputs(sequences, tokenizer, slice=slice(n_requests, n_requests + 2), reference=results, prefix_match=True)\n\n    # Verify correctness\n    for i in range(len(results_final)):\n        assert results[i] == results_final[i][: len(results[i])], f\"{i=}, {results[i]=}, {results_final[i][:len(results[i])]=}\"\n"
  },
  {
    "path": "benchmark_vllm.py",
    "content": "# /// script\n# requires-python = \">=3.12\"\n# dependencies = [\n#     \"vllm\",\n#     \"transformers\",\n#     \"datasets\",\n#     \"matplotlib\",\n#     \"tqdm\",\n#     \"pandas\",\n# ]\n# ///\n\n# Usage:\n# Run benchmark with minimal overhead: uv run benchmark_vllm.py\n# Run benchmark with profiling: ENABLE_PROFILING=1 uv run benchmark_vllm.py\n# NOTE: this script serves 2 purposes:\n# 1. it can be used to benchmark vLLM's performance, run with isolated inline dependencies.\n# 2. outside of __main__, it contains utils for producing the same payload for benchmarking.\n\nimport hashlib\nimport time  # used by generate_with_timing, so it must be importable outside of __main__\nfrom tqdm import tqdm\nfrom datasets import load_dataset\nimport random\nfrom dataclasses import dataclass\n\n@dataclass\nclass BenchmarkRequest:\n    \"\"\"Simple, framework-agnostic request data\"\"\"\n    text: str\n    max_new_tokens: int\n\n# Standard prompts used in benchmarks\nlong_prompt = \"\"\"\nThe 12 months of the year are: January, February, March,\n\"\"\".strip()\n\nshort_prompt = \"The first 20 prime numbers are: 2, 3,\"\n\n\ndef stable_hash(text):\n    \"\"\"Process-independent hash. The built-in hash() of a str is salted per process unless PYTHONHASHSEED is set.\"\"\"\n    return int.from_bytes(hashlib.sha256(text.encode(\"utf-8\")).digest()[:8], \"big\")\n\n\ndef generate_benchmark_data(tokenizer, n_requests=512, max_input_length=512, min_tokens=128, max_tokens=512):\n    \"\"\"Generate benchmark data by skipping prompts that are too long.\"\"\"\n    data = load_dataset(\"Open-Orca/OpenOrca\")[\"train\"]\n    benchmark_requests = []\n    \n    attempts = 0\n    while len(benchmark_requests) < n_requests:\n        # Deterministic sampling: stable_hash gives the same index in every process\n        idx = stable_hash(f\"req_{attempts}\") % len(data)\n        \n        system_prompt = data[idx][\"system_prompt\"] or \"\"\n        question = data[idx][\"question\"] or \"\"\n        prompt = f\"{idx}: {system_prompt} {question}\".strip()\n        \n        # Check length and skip if too long\n        tokens = tokenizer.encode(prompt)\n        if len(tokens) <= max_input_length:\n            prompt_hash = stable_hash(prompt)\n            max_new_tokens = min_tokens + (abs(prompt_hash >> 16) % (max_tokens - min_tokens 
+ 1))\n            benchmark_requests.append(BenchmarkRequest(text=prompt, max_new_tokens=max_new_tokens))\n        \n        attempts += 1\n        if attempts > n_requests * 10:  # Safety valve to prevent infinite loop\n            break\n    \n    # Add standard prompts\n    benchmark_requests.extend([\n        BenchmarkRequest(text=long_prompt, max_new_tokens=max_tokens),\n        BenchmarkRequest(text=short_prompt, max_new_tokens=max_tokens)\n    ])\n    \n    return benchmark_requests\n\n\ndef print_step_stats(steps, name):\n    \"\"\"Helper to print timing statistics for a collection of steps.\"\"\"\n    if not steps:\n        return\n    print(f\"\\n{name}:\")\n    print(f\"  Count: {len(steps)}\")\n    print(f\"  Total: {sum(steps):.4f}s\")\n    print(f\"  Mean:  {sum(steps)/len(steps):.4f}s\")\n    print(f\"  Min:   {min(steps):.4f}s\")\n    print(f\"  Max:   {max(steps):.4f}s\")\n\n\ndef generate_with_timing(llm, sequences, sampling_params, collect_detailed_metrics=False):\n    \"\"\"\n    Generate with timing, optionally collecting detailed metrics.\n    \n    Note: We track total step time rather than trying to separate prefill/decode\n    because vLLM can do both types of work within a single step, making such\n    separation misleading for performance analysis.\n    \"\"\"\n    outputs = []\n    total_step_time = 0.0\n    step_times = []\n    \n    # Optional detailed metrics\n    metrics_data = {} if not collect_detailed_metrics else {\n        'steps': [], 'requests_running': [], 'requests_waiting': [], 'preemptions': []\n    }\n    \n    # Add requests\n    for i, prompt in enumerate(sequences):\n        sp = sampling_params[i] if isinstance(sampling_params, list) else sampling_params\n        llm.llm_engine.add_request(str(i), prompt, sp)\n    \n    step_count = 0\n    \n    while llm.llm_engine.has_unfinished_requests():\n        step_start = time.perf_counter()\n        step_outputs = llm.llm_engine.step()\n        step_duration = 
time.perf_counter() - step_start\n        step_count += 1\n        \n        total_step_time += step_duration\n        step_times.append(step_duration)\n        \n        # Collect outputs\n        for output in step_outputs:\n            if output.finished:\n                outputs.append(output)\n                \n        # Optional detailed metrics\n        if collect_detailed_metrics:\n            metrics = llm.llm_engine.get_metrics()\n            metrics_data['steps'].append(step_count)\n            # Extract key metrics\n            running = waiting = preemptions = 0\n            for metric in metrics:\n                if \"requests_running\" in metric.name:\n                    running = metric.value\n                elif \"requests_waiting\" in metric.name:\n                    waiting = metric.value\n                elif \"preemptions\" in metric.name:\n                    preemptions = metric.value\n            metrics_data['requests_running'].append(running)\n            metrics_data['requests_waiting'].append(waiting)\n            metrics_data['preemptions'].append(preemptions)\n    \n    return total_step_time, step_times, outputs, metrics_data\n\n\nif __name__ == \"__main__\":\n    import os\n    import time\n    \n    os.environ[\"VLLM_TORCH_PROFILER_DIR\"] = \"./vllm_profile\"\n    os.environ['VLLM_USE_V1'] = '1'\n    \n    ENABLE_PROFILING = os.environ.get(\"ENABLE_PROFILING\", \"0\") == \"1\"\n    \n    from vllm import LLM, SamplingParams as VLLMSamplingParams\n    from transformers import AutoTokenizer\n\n    MODEL_ID = \"google/gemma-2-2b\"\n    llm = LLM(MODEL_ID, dtype=\"bfloat16\", gpu_memory_utilization=0.9, max_num_seqs=256, max_model_len=2048, disable_log_stats=False)\n    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n    \n    n_requests = 512\n    max_input_length = 512\n\n    benchmark_requests = generate_benchmark_data(tokenizer, n_requests, max_input_length)\n    \n    sequences = [req.text for req in benchmark_requests]\n   
 sampling_params_list = [\n        VLLMSamplingParams(temperature=0.0, top_p=1.0, max_tokens=req.max_new_tokens)\n        for req in benchmark_requests\n    ]\n    \n    # Warmup\n    llm.generate([\"warmup\"], VLLMSamplingParams(max_tokens=1), use_tqdm=False)\n\n\n\n    print(f\"\\n--- RUNNING {'WITH' if ENABLE_PROFILING else 'WITHOUT'} DETAILED METRICS ---\")\n    \n    # Reset memory stats to track benchmark-specific usage  \n    start_time = time.time()\n    total_step_time, step_times, outputs, metrics_data = generate_with_timing(\n        llm, sequences, sampling_params_list, collect_detailed_metrics=ENABLE_PROFILING\n    )\n    total_time = time.time() - start_time\n\n    total_output_length = sum(len(o.outputs[0].token_ids) for o in outputs)\n    prompt_tok = sum(len(o.prompt_token_ids) for o in outputs)\n    \n    print_step_stats(step_times, \"\\nstep\")\n    \n    # Get memory usage\n    print(\"\\n--- PERFORMANCE METRICS ---\")\n    print(f\"Total time: {total_time:.2f}s\")\n    print(f\"Throughput: {total_output_length / total_time:.1f} tokens/s\")\n    print(f\"Request throughput: {len(sequences) / total_time:.2f} req/s\")\n    print(f\"Total throughput (prompt+new): {(prompt_tok + total_output_length) / total_time:.1f} tokens/s\")\n        \n    if ENABLE_PROFILING and metrics_data:\n        print(\"\\n--- DETAILED METRICS ---\")\n        import pandas as pd\n        pd.DataFrame(metrics_data).to_csv('vllm_metrics.csv', index=False)\n        print(\"Metrics saved to 'vllm_metrics.csv'\")\n\n"
  },
  {
    "path": "flex_nano_vllm/__init__.py",
    "content": "\"\"\"flex-nano-vllm - FlexAttention-based nano-vllm implementation for fast Gemma 2 inference.\"\"\"\n\nfrom .modeling_gemma2 import Gemma2ForCausalLM\nfrom .inference import Inference, Sequence\n\n__version__ = \"0.1.0\"\n__all__ = [\"Gemma2ForCausalLM\", \"Inference\", \"Sequence\"]\n"
  },
  {
    "path": "flex_nano_vllm/inference.py",
    "content": "from collections import deque\nimport time\nimport torch\nfrom tqdm import tqdm\nfrom torch.nn.attention.flex_attention import BlockMask\nimport torch.nn.attention.flex_attention\nimport torch.nn.functional as F\nfrom rich.console import Console\n\nfrom flex_nano_vllm.paged_attention import PageTable, PagedKVCache\n\nfrom dataclasses import dataclass\n\nconsole = Console()\nprint(f\"torch version: {torch.__version__}\")\n\n\n@dataclass\nclass SamplingParams:\n    max_new_tokens: int = -1\n\n\ndef sample(logits_BV, greedy=True, to_cpu=False):\n    # NOTE: use greedy=True to ensure deterministic sampling\n    assert logits_BV.ndim == 2\n    B, V = logits_BV.shape\n    probs = torch.softmax(logits_BV, dim=-1)\n    if not greedy:\n        indices = torch.multinomial(probs, num_samples=1)  # shape: [B, 1]\n        logits = torch.gather(logits_BV, dim=-1, index=indices)\n        probs = torch.gather(probs, dim=-1, index=indices)\n    else:\n        probs, indices = torch.topk(probs, k=1, dim=-1)\n        logits = torch.gather(logits_BV, dim=-1, index=indices)\n    if to_cpu:\n        indices = indices.to(\"cpu\", non_blocking=True).view(B)\n        logits = logits.to(\"cpu\", non_blocking=True).view(B)\n        probs = probs.to(\"cpu\", non_blocking=True).view(B)\n        torch.cuda.synchronize()\n    return indices.tolist(), logits.tolist(), probs.tolist()\n\n\nclass Sequence:\n    def __init__(self, text: str):\n        self.done = False\n        self.text = text\n        self._output_ids = []\n        self._output_logits = []\n        self._output_probs = []\n        self.input_ids = []\n        self.finished = False\n\n        self.input_length = None\n        self.inputs = None\n\n    def add_next_token(self, token_id: torch.Tensor, logits: torch.Tensor, probs: torch.Tensor):\n        #assert token_id.ndim == 0\n        #assert logits.ndim == 0\n        self._output_ids.append(token_id)\n        self._output_logits.append(logits)\n        
self._output_probs.append(probs)\n\n    def copy(self):\n        return Sequence(self.text)\n\n    @property\n    def output_ids(self):\n        return torch.tensor(self._output_ids, dtype=torch.int64)\n\n    @property\n    def output_logits(self):\n        return torch.tensor(self._output_logits, dtype=torch.float32)\n\n    @property\n    def output_probs(self):\n        return torch.tensor(self._output_probs, dtype=torch.float32)\n\n    @property\n    def output_length(self):\n        return len(self._output_ids)\n\n    @property\n    def total_length(self):\n        return self.input_length + self.output_length\n\n    @property\n    def total_token_ids(self):\n        if self.output_length:\n            return torch.cat([self.input_ids, self.output_ids], dim=0)\n        return self.input_ids\n\n    @property\n    def last_token_id(self):\n        return self._output_ids[-1]\n\n\ndef process_sampling_params(sequences: list[Sequence], sampling_params: SamplingParams | list[SamplingParams] | None):\n    if sampling_params is None:\n        sampling_params = SamplingParams()\n    if isinstance(sampling_params, SamplingParams):\n        sampling_params = [sampling_params] * len(sequences)\n\n    assert len(sampling_params) == len(sequences), \"sampling_params must be a list of the same length as sequences\"\n\n    for seq, param in zip(sequences, sampling_params):\n        seq.params = param\n\n\nclass Inference:\n    def __init__(self, model, tokenizer, max_batch_size, max_seq_length, n_pages, page_size=128, prefill_length_limit=-1, kernel_options=None):\n        self.page_table = PageTable(n_pages=n_pages, page_size=page_size, max_batch_size=max_batch_size)\n\n        self.model = model\n        self.tokenizer = tokenizer\n        self.eos_token_id = tokenizer.eos_token_id # cache this because it's not efficient to call tokenizer.eos_token_id every time\n        self.device = model.device\n        assert max_seq_length % page_size == 0, \"max_seq_length must be 
divisible by page_size\"\n        self.max_seq_length = max_seq_length\n        self.max_batch_size = max_batch_size\n        self.kernel_options = kernel_options\n        self.prefill_length_limit = prefill_length_limit  # NOTE: control the peak memory usage of prefill\n\n        for layer in self.model.model.layers:\n            layer.self_attn.kv_cache = PagedKVCache(\n                self.page_table,\n                n_heads=self.model.model.config.num_key_value_heads,\n                head_dim=self.model.model.config.head_dim,\n                dtype=self.model.dtype,\n            ).to(self.device, non_blocking=True)\n\n        self.cudagraph_captured = False\n\n        self.input_pos = torch.zeros(self.max_batch_size, dtype=torch.int32, pin_memory=True).to(self.device, non_blocking=True)\n        self.block_mask = self.page_table.create_causal_blockmask(B=self.max_batch_size, L=self.max_seq_length)\n\n    def _prefill_sequences(\n        self, input_ids: torch.Tensor, input_pos: torch.Tensor, batch_idx_tensor: torch.Tensor, logits_to_keep: torch.Tensor\n    ) -> torch.Tensor:\n        # 1. no cuda graph\n        # 2. construct block mask and apply it in logical space\n        # 3. 
only write to kv cache, no read\n\n        # NOTE: for batch/packed prefill, we need to pass batch_idx_tensor as [1, L]\n        # input_ids is [1, L], concatenated from all sequences\n        # batch_idx_tensor is [1, L]\n        # position_ids is [1, L]\n        # logits_to_keep is [num_sequences] instead of [1]\n\n        ## padding: if there's padding\n        # input_ids should be padded with any valid token id\n        # input_pos should be padded with 0\n        # batch_idx_tensor should be padded with 0 # reserved in page table\n\n        assert input_ids.shape[0] == 1, \"input_ids must be [1, L]\"\n        assert input_pos.shape == input_ids.shape, f\"input_pos must be [1, L], got {input_pos.shape=}, {input_ids.shape=}\"\n        assert batch_idx_tensor.shape == input_ids.shape, f\"batch_idx_tensor must be [1, L], got {batch_idx_tensor.shape=}, {input_ids.shape=}\"\n\n        mask = self.page_table.create_prefill_blockmask_no_paging(batch_idx_tensor)\n        outputs = self.model.model(\n            input_ids=input_ids,\n            position_ids=input_pos + 1,  # NOTE: gemma2 uses 1-based position ids\n            # logits_to_keep=logits_to_keep,\n            flex_attn_block_mask=mask,\n            flex_attn_input_pos=input_pos,\n            flex_attn_batch_idx=batch_idx_tensor,\n            flex_attn_kernel_options=(self.kernel_options or {})  # kernel_options defaults to None, and None | dict raises TypeError\n            | {\"FORCE_USE_FLEX_ATTENTION\": True},  # NOTE: force torch compile to not use flash decoding code path\n        )\n        return self.model.lm_head(outputs.last_hidden_state[:, logits_to_keep, :])\n\n        \"\"\"\n        outputs = self.model(\n            input_ids=input_ids,\n            position_ids=input_pos + 1, # NOTE: gemma2 uses 1-based position ids\n            logits_to_keep=logits_to_keep,\n            flex_attn_block_mask=mask,\n            flex_attn_input_pos=input_pos,\n            flex_attn_batch_idx=batch_idx_tensor,\n            flex_attn_kernel_options=self.kernel_options | 
{'FORCE_USE_FLEX_ATTENTION': True}, # NOTE: force torch compile to not use flash decoding code path\n        )\n        return outputs.logits\n        \"\"\"\n\n    def prefill_sequences(self, sequences: list[Sequence]) -> torch.Tensor:\n        input_ids = torch.cat([seq.total_token_ids for seq in sequences], dim=0)\n        input_pos = torch.cat([torch.arange(seq.total_length, dtype=torch.long) for seq in sequences], dim=0)\n        batch_idx_tensor = torch.cat([torch.ones(seq.total_length, dtype=torch.long) * seq.batch_idx for seq in sequences], dim=0)\n        input_lengths = torch.tensor([seq.total_length for seq in sequences], dtype=torch.int32).to(self.device, non_blocking=True)\n        logits_to_keep = input_lengths.cumsum(dim=0) - 1\n\n        # pad the packed length up to the next multiple of 128; num_pad is 0 when it is already aligned\n        num_pad = -input_ids.shape[0] % 128\n        if num_pad > 0:\n            input_ids = F.pad(input_ids.view(-1), (0, num_pad), mode=\"constant\", value=0)\n            input_pos = F.pad(input_pos.view(-1), (0, num_pad), mode=\"constant\", value=0)\n            batch_idx_tensor = F.pad(batch_idx_tensor.view(-1), (0, num_pad), mode=\"constant\", value=0)\n            # logits_to_keep is not padded, it should have shape [num_sequences]\n\n        input_ids = input_ids.view(1, -1).to(self.device, non_blocking=True)\n        input_pos = input_pos.view(1, -1).to(self.device, non_blocking=True)\n        batch_idx_tensor = batch_idx_tensor.view(1, -1).to(self.device, non_blocking=True)\n        logits_to_keep = logits_to_keep.view(-1).to(self.device, non_blocking=True)\n\n        logits = self._prefill_sequences(input_ids, input_pos, batch_idx_tensor, logits_to_keep)\n        return logits\n\n    def get_decoding_block_mask(self, batch_idx: torch.Tensor):\n        \"\"\"\n        Args:\n            batch_idx: [B]\n        Returns:\n            block_mask: [B, H, ROWS=1, MAX_BLOCKS_IN_COL]\n            input_pos: [B]\n\n        This function slices the\n            full block mask self.block_mask:  [max_batch_size, H, 
MAX_BLOCKS_IN_ROW, MAX_BLOCKS_IN_COL]\n            using self.input_pos: [max_batch_size]\n            and batch_idx: [B]\n        \"\"\"\n\n        # NOTE: this function is entirely in logical space\n        def causal_offset(off: torch.Tensor):\n            def offset(b, h, q_idx, kv_idx):\n                return q_idx + off[b] >= kv_idx\n\n            return offset\n\n        block_mask = self.block_mask\n        input_pos = self.input_pos[batch_idx]\n        # batch_idx: [B], input_pos: [B]\n        assert batch_idx.ndim == 1, \"batch_idx must be 1D\"\n        assert input_pos.ndim == 1, \"input_pos must be 1D\"\n        (B,) = batch_idx.shape\n        input_block_idx = input_pos // block_mask.BLOCK_SIZE[0]  # [B]\n        kv_num_blocks = block_mask.kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)\n        kv_indices = block_mask.kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)\n        full_kv_num_blocks, full_kv_indices = None, None\n        if block_mask.full_kv_num_blocks is not None:\n            full_kv_num_blocks = block_mask.full_kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)  # noqa\n            full_kv_indices = block_mask.full_kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)  # noqa\n        seq_length = (1, block_mask.seq_lengths[1])\n        mask = BlockMask.from_kv_blocks(\n            kv_num_blocks,\n            kv_indices,\n            full_kv_num_blocks,\n            full_kv_indices,\n            BLOCK_SIZE=block_mask.BLOCK_SIZE,\n            mask_mod=causal_offset(input_pos),\n            seq_lengths=seq_length,\n        )\n        return mask, input_pos\n\n    def _decode_step(self, batch_idx: torch.Tensor, input_ids: torch.Tensor):\n        B = input_ids.shape[0]\n        mask, input_pos = self.get_decoding_block_mask(batch_idx)\n        mask = self.page_table.convert_logical_block_mask(mask, batch_idx)\n        outputs = self.model(\n            input_ids=input_ids.view(B, 1),\n            
position_ids=(input_pos + 1).view(B, 1),  # NOTE: position_ids is needed for decoding. For Gemma2, it's 1-based\n            flex_attn_block_mask=mask,\n            flex_attn_input_pos=input_pos.view(B, 1),\n            flex_attn_batch_idx=batch_idx.view(-1),\n            flex_attn_kernel_options=self.kernel_options,\n        )\n        return outputs.logits\n\n    def decode_step(self, batch_idx: torch.Tensor, input_ids: torch.Tensor, input_pos: torch.Tensor):\n        assert input_ids.ndim == 1, \"input_ids must be 1D\"\n        assert batch_idx.ndim == 1, \"batch_idx must be 1D\"\n        assert input_ids.shape[0] == batch_idx.shape[0], \"input_ids and batch_idx must have the same length\"\n        self.input_pos.zero_()\n        self.input_pos[batch_idx] = input_pos\n        if not self.cudagraph_captured:\n            return self._decode_step(batch_idx, input_ids)\n        else:\n            bs = input_ids.size(0)\n            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]\n            graph_vars = self.graph_vars\n            for k, v in graph_vars.items():\n                if k != \"outputs\":\n                    v.zero_()  # batch idx is 0 for undefined batch idx, so we never overwrite any kv-cache\n            graph_vars[\"input_ids\"][:bs] = input_ids\n            graph_vars[\"batch_idx\"][:bs] = batch_idx\n            graph.replay()\n            replayed = graph_vars[\"outputs\"][:bs]\n            return replayed\n\n    def _check_done(self, sequences: list[Sequence]):\n        for seq in sequences:\n            is_eos = seq.last_token_id == self.eos_token_id\n            is_max_len = seq.input_length + seq.output_length >= self.max_seq_length\n            is_max_new = seq.output_length == seq.params.max_new_tokens\n            if is_eos or is_max_len or is_max_new:\n                seq.finished = True\n                self.done.append(seq)\n                self.page_table.erase(seq.batch_idx)\n        return [seq for seq in sequences if 
not seq.finished]\n\n    def run_one_step(self):\n        # try to prefill a new sequence, if we can schedule it\n        batch = []\n        prefill_length_sum = 0\n        # NOTE: for each sequence, only allocate & reserve just enough pages for the current sequence length\n        # We do not reserve pages for the future tokens. This allows maximizing the batch size of decoding.\n        # We support preemption in case we run out of pages during decoding\n        while (\n            self.waiting\n            and self.page_table.can_reserve(self.waiting[0].total_length)  # + self.waiting[0].params.max_new_tokens)\n            and (not batch or self.prefill_length_limit == -1 or prefill_length_sum + self.waiting[0].total_length < self.prefill_length_limit)\n        ):\n            prefill_length_sum += self.waiting[0].total_length\n\n            seq = self.waiting.popleft()\n            batch_idx_int = self.page_table.allocate()\n            batch_idx_tensor = torch.tensor([batch_idx_int], device=self.device, dtype=torch.long)\n            self.page_table.reserve(batch_idx_int=batch_idx_int, batch_idx=batch_idx_tensor, seq_len=seq.total_length)  # + seq.params.max_new_tokens\n            seq.batch_idx = batch_idx_int\n            batch.append(seq)\n        if batch:\n            logits_1LV = self.prefill_sequences(batch)\n            next_token, logits, probs = sample(logits_1LV[0, :, :], to_cpu=True)\n            for b in range(len(batch)):\n                batch[b].add_next_token(next_token[b], logits[b], probs[b])\n            self.running.extend(self._check_done(batch))\n            return \"prefill\"\n        # reserve new block for running sequences if needed\n        # if a new block is needed, but there's no space, preempt the newest sequence\n        # running: [oldest, ... , newest]  waiting: [oldest, ... 
, newest]\n        while self.running:\n            seq = self.running.popleft()\n            if self.page_table.capacity[seq.batch_idx] >= seq.total_length:\n                # no need to reserve new pages\n                batch.append(seq)\n            elif self.page_table.can_reserve(seq.total_length, batch_idx_int=seq.batch_idx):\n                # reserve new pages\n                self.page_table.reserve(\n                    batch_idx_int=seq.batch_idx,\n                    batch_idx=torch.tensor([seq.batch_idx], device=self.device, dtype=torch.long),\n                    seq_len=seq.total_length,\n                )\n                batch.append(seq)\n            else:\n                # no space to run this sequence, preempt the newest sequence\n                self.running.appendleft(seq)  # first put this sequence back\n                newest = self.running.pop()  # then pop the newest sequence\n                self.waiting.appendleft(newest)\n                self.page_table.erase(newest.batch_idx)\n\n        B = len(batch)\n        # now we do decoding\n        batch_idx = torch.tensor([seq.batch_idx for seq in batch], dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)\n        input_ids = torch.tensor([seq.last_token_id for seq in batch], dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)\n        input_pos = torch.tensor([seq.total_length - 1 for seq in batch], dtype=torch.int32, pin_memory=True).to(self.device, non_blocking=True)\n        self.counts.append(B)\n        logits_BLV = self.decode_step(batch_idx, input_ids, input_pos)\n        next_token, logits, probs = sample(logits_BLV[:, -1, :], to_cpu=True)\n\n        for i in range(B):\n            batch[i].add_next_token(next_token[i], logits[i], probs[i])\n        self.running = deque(self._check_done(batch))\n        return \"decode\"\n\n    def tokenize(self, sequences: list[Sequence]):\n        self.tokenizer.padding_side = \"right\"\n        for seq in 
sequences:\n            seq.input_ids = self.tokenizer([seq.text], return_tensors=\"pt\")[\"input_ids\"].squeeze(0)\n            seq.input_length = seq.input_ids.shape[0]\n\n    @torch.inference_mode()\n    def generate(\n        self,\n        sequences: list[Sequence],\n        use_tqdm=False,\n        profiler=None,\n        greedy=False,\n        sampling_params: SamplingParams | list[SamplingParams] | None = None,\n        capture_cudagraph=False,\n        save_metrics_csv=False,\n        print_stats=False,\n    ):\n        self.counts = []\n        self.metrics_data = {\"step\": [], \"requests_running\": [], \"requests_waiting\": [], \"step_type\": []}\n        # preprocess the sequences\n        self.tokenize(sequences)\n        self.waiting = deque(sequences)\n        self.running = deque()\n        self.done = deque()\n        process_sampling_params(sequences, sampling_params)\n\n        if capture_cudagraph and not self.cudagraph_captured:\n            self.capture_decode_cudagraph()\n            self.cudagraph_captured = True\n\n        total_sequences = len(self.waiting)\n        times = []\n        step_count = 0\n        with tqdm(total=total_sequences, disable=not use_tqdm, desc=\"Generating\") as pbar:\n            prev_done = 0\n            while self.waiting or self.running:\n                step_count += 1\n                # Track metrics before step\n                self.metrics_data[\"step\"].append(step_count)\n                self.metrics_data[\"requests_running\"].append(len(self.running))\n                self.metrics_data[\"requests_waiting\"].append(len(self.waiting))\n\n                time_start = time.perf_counter()\n                step_type = self.run_one_step()\n                time_end = time.perf_counter()\n\n                # Track step type\n                self.metrics_data[\"step_type\"].append(step_type)\n                times.append({\"step_type\": step_type, \"time\": time_end - time_start})\n                if profiler:\n 
                   profiler.step()\n                # Update progress bar based on newly completed sequences\n                curr_done = len(self.done)\n                if curr_done > prev_done:\n                    pbar.update(curr_done - prev_done)\n                    prev_done = curr_done\n        if print_stats:\n            self.print_time_stats(times)\n\n        # Save metrics to CSV if requested\n        if save_metrics_csv:\n            import pandas as pd\n\n            df = pd.DataFrame(self.metrics_data)\n            df.to_csv(\"flex_nano_vllm_metrics.csv\", index=False)\n            print(\"Metrics saved to 'flex_nano_vllm_metrics.csv'\")\n\n    def capture_decode_cudagraph(self):\n        \"\"\"\n        capture cudagraph for decoding\n        \"\"\"\n        max_bs = self.max_batch_size\n        input_ids = torch.zeros(max_bs, dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)\n        batch_idx = torch.arange(max_bs, dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)\n        # NOTE: here we use logits as the final output, but we can consider using last hidden state as the output\n        outputs = torch.zeros((max_bs, 1, self.model.model.config.vocab_size), pin_memory=True).to(self.device, non_blocking=True)\n        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))  # 8 vs 16 doesn't make much difference\n        self.graphs = {}\n        self.graph_pool = None\n\n        for bs in reversed(self.graph_bs):\n            print(f\"capturing cudagraph for {bs} sequences\")\n            torch.cuda.synchronize()\n            graph = torch.cuda.CUDAGraph()\n            # warmup\n            outputs[:bs] = self._decode_step(batch_idx[:bs], input_ids[:bs])  # warmup\n            # capture\n            with torch.cuda.graph(graph, self.graph_pool):\n                outputs[:bs] = self._decode_step(batch_idx[:bs], input_ids[:bs])  # capture\n            if self.graph_pool is None:\n                
self.graph_pool = graph.pool()\n            self.graphs[bs] = graph\n            torch.cuda.synchronize()\n\n        self.graph_vars = dict(\n            input_ids=input_ids,\n            batch_idx=batch_idx,\n            outputs=outputs,\n        )\n        # in our code, the page table tensors are modified in-place, so we don't need to put them in graph vars\n\n    def print_time_stats(self, times):\n        stats = {}\n        for step in [\"decode\", \"prefill\"]:\n            step_times = [t[\"time\"] for t in times if t[\"step_type\"] == step]\n            stats[step] = {\n                \"count\": len(step_times),\n                \"total\": sum(step_times),\n                \"mean\": sum(step_times) / len(step_times) if step_times else 0,\n                \"min\": min(step_times) if step_times else 0,\n                \"max\": max(step_times) if step_times else 0,\n            }\n\n        print(\"\\nTime statistics by step type:\")\n        for step, metrics in stats.items():\n            print(f\"\\n{step}:\")\n            print(f\"  Count: {metrics['count']}\")\n            print(f\"  Total: {metrics['total']:.4f}s\")\n            print(f\"  Mean:  {metrics['mean']:.4f}s\")\n            print(f\"  Min:   {metrics['min']:.4f}s\")\n            print(f\"  Max:   {metrics['max']:.4f}s\")\n        print(f\"\\nTotal time: {sum(t['time'] for t in times):.4f}s\")\n"
  },
  {
    "path": "flex_nano_vllm/modeling_gemma2.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_gemma2.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.\n#\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Callable, Optional, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache\nfrom transformers.generation import GenerationMixin\nfrom transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_layers import GradientCheckpointingLayer\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update\nfrom transformers.modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.processing_utils import Unpack\nfrom transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging\nfrom transformers.utils.deprecation import deprecate_kwarg\nfrom transformers.utils.generic import check_model_inputs\nfrom transformers.models.gemma2.configuration_gemma2 import Gemma2Config\n\nfrom torch.nn.attention.flex_attention import flex_attention\n\nflex_attention = torch.compile(flex_attention, fullgraph=True)\n\nlogger = logging.get_logger(__name__)\n\n\nclass Gemma2RMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.zeros(dim))\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        output = self._norm(x.float())\n        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)\n        # See https://github.com/huggingface/transformers/pull/29402\n        output = output * (1.0 + self.weight.float())\n        return output.type_as(x)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.eps}\"\n\nclass Gemma2MLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_activation]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden 
dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    dropout: float = 0.0,\n    scaling: Optional[float] = None,\n    softcap: Optional[float] = None,\n    **kwargs,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    if scaling is None:\n        scaling = module.head_dim**-0.5\n\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n\n    if softcap is not None:\n        attn_weights = attn_weights / softcap\n        attn_weights = torch.tanh(attn_weights)\n        attn_weights = attn_weights * softcap\n    if attention_mask is not None:  # no matter the length, we just slice it\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    # upcast attention to fp32\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n    return attn_output, attn_weights\n\n\nclass Gemma2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Gemma2Config, 
layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.scaling = config.query_pre_attn_scalar**-0.5\n        self.attention_dropout = self.config.attention_dropout\n        self.is_causal = True\n\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.o_proj = nn.Linear(\n            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias\n        )\n        self.attn_logit_softcapping = self.config.attn_logit_softcapping\n        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == \"sliding_attention\" else None\n        self.kv_cache = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        flex_attn_block_mask = None,\n        flex_attn_input_pos = None,\n        flex_attn_batch_idx = None,\n        flex_attn_kernel_options = {'BLOCK_M': 32, 'BLOCK_N': 32, 'BLOCK_M1': 32, 'BLOCK_M2': 32, 'BLOCK_N1': 32, 'BLOCK_N2': 32, },\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n        input_shape = 
hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        # NOTE: this does not cover the sliding window case, but in my current usage, the sequence length does not exceed 4096\n        def soft_cap(score, b, h, q_idx, kv_idx):\n            score = score / self.attn_logit_softcapping\n            score = torch.tanh(score)\n            score = score * self.attn_logit_softcapping\n            return score\n\n        if self.kv_cache is not None and flex_attn_input_pos is not None:\n            key_states, value_states= self.kv_cache.update(flex_attn_input_pos, key_states, value_states, flex_attn_batch_idx)\n\n\n        attn_output = flex_attention(\n            query_states,\n            key_states,\n            value_states,\n            #dropout=self.attention_dropout if self.training else 0.0,\n            scale=self.scaling,\n            block_mask=flex_attn_block_mask,\n            score_mod=soft_cap,\n            enable_gqa=True,\n            kernel_options=flex_attn_kernel_options,\n        )\n        attn_weights = None\n        attn_output = attn_output.transpose(1, 2) # (B, H, N, E) -> (B, N, H, E)\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass Gemma2DecoderLayer(GradientCheckpointingLayer):\n    def __init__(self, config: Gemma2Config, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.config = config\n        self.attention_type = 
config.layer_types[layer_idx]\n        self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)\n        self.mlp = Gemma2MLP(config)\n        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    @deprecate_kwarg(\"last_cache_position\", version=\"4.53.0\")\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            position_embeddings=position_embeddings,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            **kwargs,\n        )\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.pre_feedforward_layernorm(hidden_states)\n        hidden_states = 
self.mlp(hidden_states)\n        hidden_states = self.post_feedforward_layernorm(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        return outputs\n\n\nclass Gemma2RotaryEmbedding(nn.Module):\n    def __init__(self, config: Gemma2Config, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and isinstance(config.rope_scaling, dict):\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    @torch.no_grad()\n    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. 
dynamic rope)\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\n@auto_docstring\nclass Gemma2PreTrainedModel(PreTrainedModel):\n    config: Gemma2Config\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Gemma2DecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n\n    _supports_static_cache = True\n    _supports_attention_backend = True\n    _can_record_outputs = {\n        \"hidden_states\": Gemma2DecoderLayer,\n        \"attentions\": Gemma2Attention,\n    }\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, Gemma2RMSNorm):\n            module.weight.data.fill_(1.0)\n\n\n@auto_docstring\nclass Gemma2Model(Gemma2PreTrainedModel):\n    def 
__init__(self, config: Gemma2Config):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Gemma2RotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @check_model_inputs\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[TransformersKwargs],\n    ) -> BaseModelOutputWithPast:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or 
inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if use_cache and past_key_values is None and not self.training:\n            past_key_values = DynamicCache()\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        # embed positions\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # normalized\n        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5\n        # See https://github.com/huggingface/transformers/pull/29402\n        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)\n        hidden_states = hidden_states * normalizer\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n\n        for decoder_layer in self.layers[: self.config.num_hidden_layers]:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            layer_outputs = decoder_layer(\n                hidden_states,\n                position_embeddings=position_embeddings,\n                attention_mask=None,\n                position_ids=position_ids,\n       
         past_key_value=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n                **kwargs,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\n@auto_docstring\nclass Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Gemma2Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: 
Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs,\n    ) -> CausalLMOutputWithPast:\n        r\"\"\"\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM\n\n        >>> model = Gemma2ForCausalLM.from_pretrained(\"google/gemma-2-9b\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2-9b\")\n\n        >>> prompt = \"What is your favorite condiment?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"What is your favorite condiment?\"\n        ```\"\"\"\n\n        if self.training and self.config._attn_implementation != \"eager\":\n            logger.warning_once(\n                \"It is strongly recommended to train Gemma2 models with the `eager` attention implementation \"\n                f\"instead of `{self.config._attn_implementation}`. 
Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.\"\n            )\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs: BaseModelOutputWithPast = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = outputs.last_hidden_state\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n        if self.config.final_logit_softcapping is not None:\n            logits = logits / self.config.final_logit_softcapping\n            logits = torch.tanh(logits)\n            logits = logits * self.config.final_logit_softcapping\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n__all__ = [\n    \"Gemma2ForCausalLM\",\n    \"Gemma2Model\",\n    
\"Gemma2PreTrainedModel\",\n]\n"
  },
  {
    "path": "flex_nano_vllm/paged_attention.py",
    "content": "# Adapted from attention-gym\n# Original source: https://github.com/pytorch-labs/attention-gym\n# License: BSD 3-Clause (see THIRD_PARTY_LICENSES.md)\n# Copyright (c) 2023, Driss Guessous\n\n# the original implementation has some bugs and has some feature that lives outside of the PageTable class\n\nfrom typing import Optional\nimport torch\nfrom torch import Tensor\nfrom torch.nn.attention.flex_attention import (\n    _identity,\n    _mask_mod_signature,\n    _score_mod_signature,\n    BlockMask,\n    noop_mask,\n    create_block_mask,\n)\n\ncreate_block_mask = torch.compile(create_block_mask)\n\n\ndef _cdiv(x: int | float | torch.Tensor, multiple: int | float | torch.Tensor):\n    return (x + multiple - 1) // multiple\n\n\nclass PagedKVCache(torch.nn.Module):\n    def __init__(self, page_table, n_heads, head_dim, dtype):\n        super().__init__()\n        cache_shape = (1, n_heads, page_table.n_pages * page_table.page_size, head_dim)\n        self.register_buffer(\"k_cache\", torch.zeros(cache_shape, dtype=dtype))\n        self.register_buffer(\"v_cache\", torch.zeros(cache_shape, dtype=dtype))\n\n        self.page_table = page_table\n\n    def update(self, input_pos, k_val, v_val, batch_idx=None):\n        assert batch_idx is not None, \"batch_idx is required for paged kv cache, are you using non-paged attention?\"\n\n        if batch_idx.ndim == 1:\n            # batch_idx should be [B] (decode)\n            return self.page_table.assign(batch_idx, input_pos, k_val, v_val, self.k_cache, self.v_cache)\n        else:\n            assert batch_idx.ndim == 2, \"batch_idx must be 1D or 2D\"\n            # batch_idx should be [1, L] (batch prefill)\n            return self.page_table.assign_prefill_no_paging(batch_idx, input_pos, k_val, v_val, self.k_cache, self.v_cache)\n\n\nclass PageTable:\n    \"\"\"\n    PageTable is a modified version of PagedAttention from attention-gym.\n\n    PageTable improves it by:\n    - maintaining a cpu copy of the 
page table, to avoid device-to-host transfers\n    - support batch prefill\n    - fix the bug in the original code in mask_mod and score_mod by mapping physical batch index to logical batch index\n    - subsuming the free_batch_idx into the page table, so we don't need to maintain it separately\n    \"\"\"\n\n    def __init__(\n        self,\n        n_pages: int,\n        page_size: int,\n        max_batch_size: int,\n        device: str = \"cuda\",\n    ):\n        self.n_pages = n_pages\n        self.page_size = page_size\n        self.max_batch_size = max_batch_size\n        self.device = device\n\n        # page table: [logical_batch_idx, logical_block_idx] -> physical_page_idx\n        self.page_table = -torch.ones((max_batch_size, self.n_pages), dtype=torch.int64, device=device)\n        self.page_table[0, :] = 0  # page 0 is reserved for simpler code in assign_prefill_no_paging\n        self.page_table_cpu = [[] for _ in range(max_batch_size)]\n\n        self.capacity = [0 for _ in range(max_batch_size)]  # capacity: batch_idx -> number of pages allocated * page size\n        self.free_pages = list(reversed(range(1, n_pages)))  # page 0 is reserved for simpler code in assign_prefill_no_paging\n        self.free_batch_idx = list(reversed(range(1, max_batch_size)))  # batch_idx 0 is reserved for no-op\n\n        # [logical_batch_idx, physical_page_idx] -> logical_page_idx\n        self.physical_to_logical = -torch.ones((max_batch_size, n_pages), dtype=torch.int64, device=device)\n\n    def can_reserve(self, size: int, batch_idx_int: int | None = None) -> bool:\n        \"\"\"check if we can reserve new pages for an existing request or a new request, without gpu operations\"\"\"\n        if batch_idx_int is None:\n            # check if we can schedule a new request\n            return self.pages_available * self.page_size >= size and len(self.free_batch_idx) > 0\n        else:\n            # check if we can reserve new pages for an existing request\n          
  return self.reserve(batch_idx_int, None, size, dry_run=True)\n\n    def allocate(self) -> int:\n        \"\"\"allocate a new batch\"\"\"\n        batch_idx = self.free_batch_idx.pop()\n\n        self.capacity[batch_idx] = 0\n        self.physical_to_logical[batch_idx, :] = -1\n        self.page_table[batch_idx, :] = -1\n        return batch_idx\n\n    @property\n    def pages_available(self) -> int:\n        return len(self.free_pages)\n\n    def reserve(self, batch_idx_int: int, batch_idx: torch.Tensor, seq_len: int, dry_run: bool = False) -> bool:\n        \"\"\"\n        Requests the capacity of a given batch to be at least enough to\n        hold `seq_len` elements.\n\n        Args:\n            batch_idx_int (int): batch index to be reserved (used for the CPU-side metadata);\n            batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`. May be None when `dry_run` is True.\n            seq_len (int): minimum capacity, in tokens, for the given batch.\n            dry_run (bool): if True, only check whether the reservation would fit; no state is updated.\n\n        Returns:\n            bool: True if the capacity is already sufficient or the reservation succeeded. With `dry_run=True`, returns whether enough free pages are available.\n\n        Raises:\n            RuntimeError: if `dry_run` is False and there are not enough free pages (in this case, no update is done).\n        \"\"\"\n\n        if seq_len <= self.capacity[batch_idx_int]:\n            return True\n\n        num_pages_to_allocate = _cdiv(seq_len - self.capacity[batch_idx_int], self.page_size)\n\n        can_allocate = num_pages_to_allocate <= self.pages_available\n        if dry_run:\n            return can_allocate\n\n        if not can_allocate:\n            raise RuntimeError(\n                f\"Cannot reserve {num_pages_to_allocate} pages for a sequence of length {seq_len} \"\n                f\"in batch {batch_idx_int}. Only {self.pages_available} pages available. 
\"\n                f\"Current capacity is {self.capacity[batch_idx_int]} tokens.\"\n            )\n\n        start_page_idx = self.capacity[batch_idx_int] // self.page_size\n        end_page_idx = start_page_idx + num_pages_to_allocate\n\n        # find empty physical pages\n        allocated_pages_list = self.free_pages[-num_pages_to_allocate:]\n        allocated_pages = torch.tensor(allocated_pages_list, device=self.device)\n        # update page table\n        self.page_table[batch_idx, start_page_idx:end_page_idx] = allocated_pages\n\n        # update metadata\n        self.physical_to_logical[batch_idx, allocated_pages] = torch.arange(\n            start_page_idx,\n            end_page_idx,\n            device=self.device,\n        )\n        # update cpu side metadata\n        self.page_table_cpu[batch_idx_int] += allocated_pages_list\n        self.free_pages = self.free_pages[:-num_pages_to_allocate]\n        self.capacity[batch_idx_int] += num_pages_to_allocate * self.page_size\n        return True\n\n    def erase(self, batch_idx: int) -> None:\n        \"\"\"\n        Removes a single batch from paged attention.\n\n        Args:\n            batch_idx (int): batch index to be removed;\n        \"\"\"\n        # NOTE: the GPU side data will only be reset/overwritten when we allocate it for a new batch\n        self.free_batch_idx.append(batch_idx)\n        allocated_pages_cpu = self.page_table_cpu[batch_idx]\n        self.free_pages.extend(reversed(allocated_pages_cpu))\n        self.page_table_cpu[batch_idx] = []\n\n    def assign(\n        self,\n        batch_idx: torch.Tensor,\n        input_pos: torch.Tensor,\n        k_val: torch.Tensor,\n        v_val: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n    ) -> None:\n        \"\"\"\n        Assigns new contents `val` to the storage `cache` at the location\n        `batch_idx` and `input_pos`.\n\n        Args:\n            batch_idx (Tensor): batch index; shape 
:math:`(B)`.\n            input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`.\n            val (Tensor): value to be assigned; shape :math:`(B, H, S, D)`.\n            cache (Tensor): the cache to store the values; shape :math:`(1, H, MAX_S, D)`.\n        \"\"\"\n        if k_val.requires_grad:\n            raise RuntimeError(\"val must not require gradient\")\n\n        B, H, S, K_D = k_val.shape\n        _, H_cache, MAX_S, D_cache = k_cache.shape\n        assert H_cache == H, \"number of heads must match\"\n        assert MAX_S >= S, \"cache must have enough space\"\n        assert D_cache == K_D, \"hidden dim must match\"\n        assert input_pos.shape == (B, S), \"input_pos must have the same shape as val\"\n        assert batch_idx.shape == (B,), \"batch_idx must have one dimension only\"\n\n        V_D = v_val.shape[3]\n        if B != batch_idx.shape[0]:\n            raise RuntimeError(f\"Expected val and batch_idx to have the same batch size but got B={B} and B={batch_idx.shape[0]}.\")\n        if H != k_cache.shape[1]:\n            raise RuntimeError(f\"Expected val and cache to have the same number of heads but got H={H} and H={k_cache.shape[1]}.\")\n        if S != input_pos.shape[1]:\n            raise RuntimeError(f\"Expected val and input_pos to have the same length but got S={S} and S={input_pos.shape[1]}.\")\n        if K_D != k_cache.shape[3]:\n            raise RuntimeError(f\"Expected k_val and k_cache to have the same hidden dim but got D={K_D} and D={k_cache.shape[3]}.\")\n        if V_D != v_cache.shape[3]:\n            raise RuntimeError(f\"Expected v_val and v_cache to have the same hidden dim but got D={V_D} and D={v_cache.shape[3]}.\")\n\n        # find address\n        logical_block_idx = input_pos // self.page_size  # [B, S]\n        logical_block_offset = input_pos % self.page_size  # [B, S]\n\n        # NOTE: this code path is only used for decoding. 
For batch prefill, use assign_prefill_no_paging() instead\n        physical_block_idx = torch.gather(self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)).to(torch.int32)  # [B, S]\n\n        addr = (physical_block_idx * self.page_size + logical_block_offset).view(-1)  # [B*S]\n\n        k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)\n        v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)\n\n        k_cache[:, :, addr, :] = k_val\n        v_cache[:, :, addr, :] = v_val\n\n        return k_cache, v_cache\n\n    def convert_logical_block_mask(\n        self,\n        block_mask: BlockMask,\n        batch_idx: Optional[torch.Tensor] = None,\n    ) -> BlockMask:\n        \"\"\"\n        Converts a logical block mask by mapping its logical kv indices to the corresponding\n        physical kv indices.\n\n        Args:\n            block_mask (BlockMask): logical block mask;\n                kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`.\n            batch_idx (Tensor): batch index corresponding to the block_mask\n                batch dimension. 
This provides flexibility to convert a\n                block mask with smaller batch size than the page table;\n                shape :math:`(B)`.\n        \"\"\"\n        B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape\n\n        if block_mask.BLOCK_SIZE[1] != self.page_size:\n            raise RuntimeError(\n                f\"Expected block_mask to have the same column block size as page_size, but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}\"\n            )\n\n        device = block_mask.kv_num_blocks.device\n\n        if batch_idx is None:\n            batch_idx = torch.arange(B, device=device)\n\n        assert batch_idx.ndim == 1, \"batch_idx must be a 1D tensor\"\n        assert batch_idx.shape[0] == B, \"batch_idx must have the same batch size as block_mask\"\n        assert B <= self.max_batch_size, \"block_mask batch size must be less than or equal to max_batch_size\"\n\n        page_table = self.page_table[batch_idx]\n\n        def transform(num_blocks, indices):\n            \"\"\"\n            transform the block mask from [B, H, num_q_blocks, num_logical_kv_blocks]\n            to [B, H, num_q_blocks, num_physical_kv_blocks]\n\n            kv_num_blocks: [B, H, num_q_blocks] -> unchanged\n            kv_indices: [B, H, num_q_blocks, num_logical_kv_blocks] -> [B, H, num_q_blocks, num_physical_kv_blocks]\n            \"\"\"\n            if num_blocks is None:\n                return None, None\n            new_kv_num_blocks = num_blocks.clone()\n            new_kv_indices = torch.zeros((B, H, ROWS, self.n_pages), dtype=torch.int32, device=device)\n            new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (\n                torch.gather(page_table, 1, indices.view(B, -1).to(torch.int64)).view(block_mask.kv_indices.shape).to(torch.int32)\n            )\n            return new_kv_num_blocks, new_kv_indices\n\n        new_kv_num_blocks, new_kv_indices = transform(block_mask.kv_num_blocks, block_mask.kv_indices)\n        new_full_kv_num_blocks, 
new_full_kv_indices = transform(block_mask.full_kv_num_blocks, block_mask.full_kv_indices)\n\n        new_mask_mod = self.get_mask_mod(block_mask.mask_mod, batch_idx)\n\n        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)\n        return BlockMask.from_kv_blocks(\n            new_kv_num_blocks,\n            new_kv_indices,\n            new_full_kv_num_blocks,\n            new_full_kv_indices,\n            block_mask.BLOCK_SIZE,\n            new_mask_mod,\n            seq_lengths=seq_lengths,\n        )\n\n    def get_logical_kv_idx(self, physical_batch_idx: torch.Tensor, physical_kv_idx: torch.Tensor, batch_idx: torch.Tensor):\n        logical_batch_idx = batch_idx[physical_batch_idx]\n        physical_kv_block = physical_kv_idx // self.page_size\n        physical_kv_offset = physical_kv_idx % self.page_size\n        logical_block_idx = self.physical_to_logical[logical_batch_idx, physical_kv_block]\n        logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset\n        is_valid = logical_block_idx >= 0\n        safe_logical_kv_idx = logical_kv_idx.clamp(min=0)\n        return is_valid, safe_logical_kv_idx\n\n    def get_mask_mod(self, mask_mod: Optional[_mask_mod_signature], batch_idx: torch.Tensor) -> _mask_mod_signature:\n        \"\"\"\n        Converts a mask_mod based on mapping from the physical block index to the logical\n        block index.\n\n        Args:\n            mask_mod (_mask_mod_signature): mask_mod based on the logical block index.\n        \"\"\"\n        if mask_mod is None:\n            mask_mod = noop_mask\n\n        def new_mask_mod(\n            b: torch.Tensor,\n            h: torch.Tensor,\n            q_idx: torch.Tensor,\n            physical_kv_idx: torch.Tensor,\n        ):\n            is_valid, safe_logical_kv_idx = self.get_logical_kv_idx(b, physical_kv_idx, batch_idx)\n            return torch.where(is_valid, mask_mod(b, h, q_idx, safe_logical_kv_idx), False)\n\n        return 
new_mask_mod\n\n    # NOTE: not used in the current codebase\n    def get_score_mod(self, score_mod: Optional[_score_mod_signature], batch_idx: torch.Tensor) -> _score_mod_signature:\n        \"\"\"\n        Converts a score_mod based on mapping from the physical block index to the logical\n        block index.\n\n        Args:\n            score_mod (_score_mod_signature): score_mod based on the logical block index.\n        \"\"\"\n        if score_mod is None:\n            score_mod = _identity\n\n        def new_score_mod(\n            score: torch.Tensor,\n            b: torch.Tensor,\n            h: torch.Tensor,\n            q_idx: torch.Tensor,\n            physical_kv_idx: torch.Tensor,\n        ):\n            is_valid, safe_logical_kv_idx = self.get_logical_kv_idx(b, physical_kv_idx, batch_idx)\n            return torch.where(\n                is_valid,\n                score_mod(score, b, h, q_idx, safe_logical_kv_idx),\n                float(\"-inf\"),\n            )\n\n        return new_score_mod\n\n    def create_causal_blockmask(self, B, L):\n        \"\"\"A minimal, unoptimized causal block mask creation function\"\"\"\n\n        def causal(b, h, q_idx, kv_idx):\n            return q_idx >= kv_idx\n\n        return create_block_mask(causal, B=B, H=None, Q_LEN=L, KV_LEN=L, BLOCK_SIZE=self.page_size, device=self.device)\n\n    def create_prefill_blockmask_no_paging(self, batch_idx: Tensor, BLOCK_SIZE: int = 128):\n        \"\"\"\n        there's no prefix sharing implemented, batch_idx is the document id, batch_idx is not guaranteed to be sorted\n        \"\"\"\n        assert batch_idx.ndim == 2, \"batch_idx must be a 2D tensor\"\n        assert batch_idx.shape[0] == 1, \"batch_idx must have batch size 1\"\n        L = batch_idx.shape[1]\n        docs = batch_idx.view(-1)\n\n        def document_causal(b, h, q_idx, kv_idx):\n            causal_mask = q_idx >= kv_idx\n            document_mask = docs[q_idx] == docs[kv_idx]\n            return 
causal_mask & document_mask\n\n        return create_block_mask(document_causal, B=1, H=None, Q_LEN=L, KV_LEN=L, BLOCK_SIZE=BLOCK_SIZE)\n\n    # Writes the prefill k/v into the cache, similar to assign(), but returns the original k_val, v_val\n    # instead of k_cache, v_cache, so prefill attention runs on the unpaged k/v.\n    def assign_prefill_no_paging(\n        self,\n        batch_idx: torch.Tensor,\n        input_pos: torch.Tensor,\n        k_val: torch.Tensor,\n        v_val: torch.Tensor,\n        k_cache: torch.Tensor,\n        v_cache: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        assigns kv to the cache and returns the original kv\n\n        batch_idx: [1, L]\n        input_pos: [1, L]\n        k_val: [1, H, L, D]\n        v_val: [1, H, L, D]\n        k_cache: [1, H, MAX_S, D]\n        v_cache: [1, H, MAX_S, D]\n        \"\"\"\n\n        assert batch_idx.ndim == 2, \"batch_idx must be a 2D tensor\"\n        assert input_pos.ndim == 2, \"input_pos must be a 2D tensor\"\n        assert k_val.ndim == 4, \"k_val must be a 4D tensor\"\n        assert v_val.ndim == 4, \"v_val must be a 4D tensor\"\n        assert k_cache.ndim == 4, \"k_cache must be a 4D tensor\"\n        assert v_cache.ndim == 4, \"v_cache must be a 4D tensor\"\n        assert batch_idx.shape[0] == 1, \"batch_idx must have batch size 1\"\n\n        # map each logical position to its physical slot: page index * page_size + offset within the page\n        input_pos_block_idx = input_pos // self.page_size\n        input_pos_offset_in_block = input_pos % self.page_size\n        physical_kv_idx = self.page_table[batch_idx, input_pos_block_idx] * self.page_size + input_pos_offset_in_block\n        k_cache[:, :, physical_kv_idx.view(-1), :] = k_val\n        v_cache[:, :, physical_kv_idx.view(-1), :] = v_val\n\n        return k_val, v_val\n"
  },
  {
    "path": "plot_metrics.py",
    "content": "# /// script\n# requires-python = \">=3.12\"\n# dependencies = [\n#     \"pandas\",\n#     \"matplotlib\",\n# ]\n# ///\n\nimport pandas as pd\nimport matplotlib.pyplot as plt\n\n# Read the CSV files\nflex_nano_df = pd.read_csv('flex_nano_vllm_metrics.csv')\nvllm_df = pd.read_csv('vllm_metrics.csv')\n\n# Create figure with subplots\nfig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))\n\n# Plot 1: Running requests comparison\nax1.plot(flex_nano_df['step'], flex_nano_df['requests_running'], \n         label='Flex Nano VLLM', color='blue', linewidth=1.5)\nax1.plot(vllm_df['steps'], vllm_df['requests_running'], \n         label='VLLM', color='red', linewidth=1.5)\nax1.set_title('Running Requests Over Time')\nax1.set_xlabel('Step')\nax1.set_ylabel('Running Requests')\nax1.legend()\nax1.grid(True, alpha=0.3)\n\n# Plot 2: Waiting requests comparison\nax2.plot(flex_nano_df['step'], flex_nano_df['requests_waiting'], \n         label='Flex Nano VLLM', color='blue', linewidth=1.5)\nax2.plot(vllm_df['steps'], vllm_df['requests_waiting'], \n         label='VLLM', color='red', linewidth=1.5)\nax2.set_title('Waiting Requests Over Time')\nax2.set_xlabel('Step')\nax2.set_ylabel('Waiting Requests')\nax2.legend()\nax2.grid(True, alpha=0.3)\n\n# Plot 3: Flex Nano VLLM step types\nprefill_steps = flex_nano_df[flex_nano_df['step_type'] == 'prefill']\ndecode_steps = flex_nano_df[flex_nano_df['step_type'] == 'decode']\n\nax3.scatter(prefill_steps['step'], prefill_steps['requests_running'], \n           label='Prefill', alpha=0.6, s=10, color='green')\nax3.scatter(decode_steps['step'], decode_steps['requests_running'], \n           label='Decode', alpha=0.6, s=10, color='orange')\nax3.set_title('Flex Nano VLLM: Running Requests by Step Type')\nax3.set_xlabel('Step')\nax3.set_ylabel('Running Requests')\nax3.legend()\nax3.grid(True, alpha=0.3)\n\n# Plot 4: Total requests (running + waiting)\nflex_nano_total = flex_nano_df['requests_running'] + 
flex_nano_df['requests_waiting']\nvllm_total = vllm_df['requests_running'] + vllm_df['requests_waiting']\n\nax4.plot(flex_nano_df['step'], flex_nano_total, \n         label='Flex Nano VLLM Total', color='blue', linewidth=1.5)\nax4.plot(vllm_df['steps'], vllm_total, \n         label='VLLM Total', color='red', linewidth=1.5)\nax4.set_title('Total Requests (Running + Waiting)')\nax4.set_xlabel('Step')\nax4.set_ylabel('Total Requests')\nax4.legend()\nax4.grid(True, alpha=0.3)\n\nplt.tight_layout()\nplt.savefig('metrics_comparison.png', dpi=300, bbox_inches='tight')\nprint(\"Metrics comparison saved as 'metrics_comparison.png'\")\nplt.show()\n\n# Print some summary statistics\nprint(\"\\n=== Summary Statistics ===\")\nprint(f\"Flex Nano VLLM - Max running: {flex_nano_df['requests_running'].max()}\")\nprint(f\"VLLM - Max running: {vllm_df['requests_running'].max()}\")\nprint(f\"Flex Nano VLLM - Max waiting: {flex_nano_df['requests_waiting'].max()}\")\nprint(f\"VLLM - Max waiting: {vllm_df['requests_waiting'].max()}\")\nprint(f\"Flex Nano VLLM - Total steps: {len(flex_nano_df)}\")\nprint(f\"VLLM - Total steps: {len(vllm_df)}\")"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"flex-nano-vllm\"\nversion = \"0.1.0\"\ndescription = \"Flex-attention based nano-vllm implementation for fast Gemma 2 inference\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\ndependencies = [\n    \"accelerate>=1.9.0\",\n    \"datasets>=3.0.0\",\n    \"hf-transfer>=0.1.9\",\n    \"matplotlib>=3.10.3\",\n    \"torch>=2.7.1\",\n    \"tqdm>=4.67.1\",\n    \"transformers>=4.53.2\",\n    \"triton>=3.3.1\",\n]\n\n[build-system]\nrequires = [\"setuptools>=61.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.setuptools]\npackages = [\"flex_nano_vllm\"]\n\n[dependency-groups]\ndev = [\n    \"rich>=14.1.0\",\n]\n\n[tool.uv.sources]\ntransformers = { git = \"https://github.com/huggingface/transformers\", rev = \"34133d0a\" }\n"
  },
  {
    "path": "visualize.py",
    "content": "# /// script\n# requires-python = \">=3.12\"\n# dependencies = [\n#     \"matplotlib\",\n#     \"numpy\",\n# ]\n# ///\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n# Data\nconfigs = ['50% GPU', '90% GPU', '90% GPU\\n(high batch)']\nvllm_output = [3020, 3772, 3840]\nflex_output = [2313, 3076, 3440]\n\n# Create figure\nfig, ax = plt.subplots(figsize=(12, 8))\n\nx = np.arange(len(configs))\nwidth = 0.35\n\nbars1 = ax.bar(x - width/2, vllm_output, width, label='vLLM v1', color='#1f77b4', alpha=0.8)\nbars2 = ax.bar(x + width/2, flex_output, width, label='flex-nano-vllm', color='#ff7f0e', alpha=0.8)\n\nax.set_title('Output Tokens/s Comparison by Configuration', fontsize=16, fontweight='bold', pad=20)\nax.set_ylabel('Tokens/s', fontsize=14)\nax.set_xlabel('GPU Memory Configuration', fontsize=14)\nax.set_xticks(x)\nax.set_xticklabels(configs, fontsize=12)\nax.legend(fontsize=12)\nax.grid(axis='y', alpha=0.3)\n\n# Add value labels\nfor bar in bars1:\n    height = bar.get_height()\n    ax.text(bar.get_x() + bar.get_width()/2., height + 50,\n            f'{int(height)}', ha='center', va='bottom', fontweight='bold', fontsize=11)\n\n# Add value labels with percentages for flex-nano-vllm\nfor i, bar in enumerate(bars2):\n    height = bar.get_height()\n    percentage = (flex_output[i] / vllm_output[i]) * 100\n    ax.text(bar.get_x() + bar.get_width()/2., height + 50,\n            f'{int(height)}\\n({percentage:.1f}%)', ha='center', va='bottom', fontweight='bold', fontsize=11)\n\nplt.tight_layout()\n\n# Save the plot\nplt.savefig('tokens_per_second_comparison.png', dpi=300, bbox_inches='tight')\nprint(\"Simple comparison saved as 'tokens_per_second_comparison.png'\")\n\nplt.show()"
  }
]