Repository: changjonathanc/flex-nano-vllm Branch: main Commit: 9bfd8ed60359 Files: 13 Total size: 94.1 KB Directory structure: gitextract_2vld2n29/ ├── .gitignore ├── LICENSE ├── README.md ├── THIRD_PARTY_LICENSES.md ├── benchmark.py ├── benchmark_vllm.py ├── flex_nano_vllm/ │ ├── __init__.py │ ├── inference.py │ ├── modeling_gemma2.py │ └── paged_attention.py ├── plot_metrics.py ├── pyproject.toml └── visualize.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__/ *.egg-info/ trace_dir/ *.csv ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2025 Jonathan Chang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================
# flex-nano-vllm

A minimal, FlexAttention-based, vLLM-style inference engine for fast Gemma 2 inference.

## Introduction

This project has no flash-attn dependency and no custom Triton kernels. Everything is implemented with FlexAttention. The code is commented and the structure is flat.

Read the accompanying write-up: [vLLM flex attention from scratch](https://jonathanc.net/blog/vllm-flex-attention-from-scratch).

## Code Structure

```
flex-nano-vllm/
├── benchmark.py            # Testing and benchmarking script.
├── benchmark_vllm.py       # vLLM comparison benchmark (uses uv inline dependency to run vLLM).
├── visualize.py            # Performance visualization script.
└── flex_nano_vllm/
    ├── inference.py        # Main inference engine, uses paged attention.
    ├── modeling_gemma2.py  # Gemma2 model implementation, copied from transformers.
    └── paged_attention.py  # Paged attention implementation, including page table and paged kv cache. Based on attention-gym.
```

## Quick Start

```
uv sync

# run test and benchmark
uv run benchmark.py

# compare with vllm
uv run benchmark_vllm.py

# enable profiling to save more metrics to a csv file
# ENABLE_PROFILING=1 uv run benchmark_vllm.py
```

## Results

Test configuration:

- PyTorch version: 2.7.1+cu128
- GPU: RTX 3090 x 1 (24GB)
- Model: google/gemma-2-2b
- Workload: 512 requests, max 512 input tokens, variable output tokens (128-512)
- Configs tested: vLLM at 50% & 90% GPU memory, flex-nano-vllm with the same page allocation as vLLM

| Implementation | Output Tokens/s | Request/s | Total Throughput* |
|---------------|----------------|-----------|------------------|
| vLLM v1, 90% GPU memory, high batch size† | 3,840 | 17.67 | 7,234 |
| vLLM v1, 90% GPU memory | 3,772 | 15.26 | 6,401 |
| flex-nano-vllm, 90% GPU memory, high batch size† | 3,440 | 14.30 | 5,817 |
| flex-nano-vllm, 90% GPU memory | 3,076 | 13.06 | 5,382 |
| vLLM v1, 50% GPU memory | 3,020 | 13.74 | 5,448 |
| flex-nano-vllm, 50% GPU memory | 2,313 | 9.96 | 4,068 |

*Total throughput includes both input and output tokens

† High batch size means max_num_seqs=512 in vLLM (maximum allowed concurrency)

![Performance Comparison](tokens_per_second_comparison.png)

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

Third-party code incorporated in this project retains its original licenses. See [THIRD_PARTY_LICENSES.md](THIRD_PARTY_LICENSES.md) for details.

## Acknowledgments

- [GeeeekExplorer/nano-vllm](https://github.com/GeeeekExplorer/nano-vllm): this project is inspired by nano-vllm.
- [pytorch-labs/attention-gym](https://github.com/pytorch-labs/attention-gym): The paged attention implementation is based on attention-gym.
- [huggingface/transformers](https://github.com/huggingface/transformers): I copied the gemma2 model from transformers and modified it to use flex attention / paged attention.
- [vllm-project/vllm](https://github.com/vllm-project/vllm): vLLM has support for flex attention backend, which helped me find a useful flag in flex_attention. ================================================ FILE: THIRD_PARTY_LICENSES.md ================================================ # Third Party Licenses This project incorporates code from third-party open source projects. The following licenses apply to the respective components: ## Hugging Face Transformers This project includes modified code from the transformers project: - **Source**: https://github.com/huggingface/transformers - **Files**: `modeling_gemma2.py` (Gemma2 model implementation) - **License**: Apache License 2.0 - **Copyright**: Copyright 2024 Google Inc. HuggingFace Inc. team ### Apache License 2.0 Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ## Attention Gym This project includes modified code from the attention-gym project: - **Source**: https://github.com/pytorch-labs/attention-gym - **Files**: `paged_attention.py` and related components - **License**: BSD 3-Clause License - **Copyright**: Copyright (c) 2023, Driss Guessous ### BSD 3-Clause License Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: benchmark.py ================================================ """ This script is used to test the correctness of the paged attention implementation and to benchmark it. Simplified interface of the Inference class: Usage: llm = Inference(...) sequences = [Sequence(text) for text in texts] llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True, capture_cudagraph=True) outputs = debug_print_outputs(sequences, tokenizer) ## correctness: 1. run 2 different sequences with different lengths; the output should match huggingface .generate() (this output has been verified to be correct outside of this script) 2.
run the same sequences again with .generate() and paged attention; the output should match. This tests that the paged attention Inference class has no side effects that can alter the output across .generate() calls ### correctness with dynamic batching 3. run with a different number of requests, with the same 2 sequences mixed into the batch; the output should match 1. ### correctness with cuda graph 4. after cudagraph capture, run a batch of requests with the same 2 sequences mixed in; the output should match 1. ### tests we don't cover, but might be useful to have - PageTable unit tests - tests with output length longer than one page (128 tokens) """ from transformers import AutoTokenizer import time from flex_nano_vllm import Gemma2ForCausalLM from flex_nano_vllm.inference import Inference, Sequence, SamplingParams from benchmark_vllm import generate_benchmark_data, long_prompt, short_prompt import torch from rich.console import Console from torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler console = Console() torch.set_float32_matmul_precision("high") def get_profiler_context(): activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] profiler_context = profile( activities=activities, schedule=schedule(wait=0, warmup=10, active=10, repeat=10), on_trace_ready=tensorboard_trace_handler("trace_dir"), record_shapes=False, profile_memory=False, with_stack=True, with_flops=False, ) return profiler_context # long_prompt and short_prompt are imported from benchmark_vllm def debug_print_outputs(sequences, tokenizer, slice=slice(None), reference=None, prefix_match=False): results = [] for i, seq in enumerate(sequences[slice]): output_ids = seq.output_ids output_decoded = tokenizer.decode(output_ids, skip_special_tokens=True) input_decoded = tokenizer.decode(seq.input_ids, skip_special_tokens=False) # If a reference is provided, print a one-line match/mismatch status (details only on mismatch) if reference is not None and i < len(reference): # Check for exact match or
prefix match if prefix_match: matches = output_decoded.startswith(reference[i]) else: matches = output_decoded == reference[i] if matches: match_type = "prefix match" if prefix_match else "match" console.print(f"i={i} ✓ {match_type}", style="green", markup=False) else: mismatch_type = "PREFIX MISMATCH" if prefix_match else "MISMATCH" console.print(f"i={i} ✗ {mismatch_type}", style="red bold", markup=False) console.print(f" expected: {reference[i][:32]}...", style="red", markup=False) console.print(f" got: {output_decoded[:32]}...", style="red", markup=False) console.print(f" input: {input_decoded}", style="dim", markup=False) else: # Normal detailed output when no reference console.print(f"{i=} {input_decoded=} {output_decoded[:32]=}", style="bold", markup=False) results.append(output_decoded) return results if __name__ == "__main__": # Load model and tokenizer model_id = "google/gemma-2-2b" tokenizer = AutoTokenizer.from_pretrained(model_id) model = Gemma2ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto").eval() # model = torch.compile(model) B = 8 max_new_tokens = 8 # settings to match vLLM (see benchmark_vllm.py) paged_attn_max_batch_size = 256 # match vLLM max_num_seqs=256 max_seq_length = 2048 max_input_length = 1024 # token_allocation = 45_360 # 50% memory usage (use this instead of the 90% value below to reproduce the 50% config), this number is derived by running uv run benchmark_vllm.py and checking the logs (on a 3090) token_allocation = 140_432 # 90% memory usage, this number is derived by running uv run benchmark_vllm.py and checking the logs (on a 3090) prefill_length_limit = 1024 * 8 # helps control peak memory usage for prefill page_size = 128 n_pages = int(token_allocation) // page_size print("initializing flex-nano-vllm inference") llm = Inference( model, tokenizer, max_batch_size=paged_attn_max_batch_size, max_seq_length=max_seq_length, n_pages=n_pages, page_size=page_size, kernel_options={"BLOCK_M": 32, "BLOCK_N": 32}, prefill_length_limit=prefill_length_limit, ) print("test 1") ## test 1 sequences =
[Sequence(long_prompt), Sequence(short_prompt)] llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True) results = debug_print_outputs(sequences, tokenizer) del sequences torch.cuda.empty_cache() print("test 2, same batch") sequences = [Sequence(long_prompt), Sequence(short_prompt)] llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True) results2 = debug_print_outputs(sequences, tokenizer, reference=results) for i in range(len(results)): assert results[i] == results2[i], f"{i=}, {results[i]=}, {results2[i]=}" del sequences torch.cuda.empty_cache() print("test 2.1: reverse order") sequences = [Sequence(short_prompt), Sequence(long_prompt)] llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True) # Debug in the right order: [long_prompt, short_prompt] to match reference reordered_sequences = [sequences[1], sequences[0]] # [long_prompt, short_prompt] results21 = debug_print_outputs(reordered_sequences, tokenizer, reference=results) for i in range(len(results)): assert results[i] == results21[i], f"{i=}, {results[i]=}, {results21[i]=}" del sequences torch.cuda.empty_cache() print("test 3: batch with other sequence") sequences = [Sequence(short_prompt), Sequence(long_prompt), Sequence("hi")] llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True) # Debug just the sequences we care about in the right order: [long_prompt, short_prompt] comparison_sequences = [sequences[1], sequences[0]] # [long_prompt, short_prompt] results3 = debug_print_outputs(comparison_sequences, tokenizer, reference=results) del sequences for i in range(len(results)): assert results[i] == results3[i], f"{i=}, {results[i]=}, {results3[i]=}" torch.cuda.empty_cache() print("test 4: batch with other sequence, capture cudagraph") sequences = [ Sequence("this is a test messaage hello "), Sequence(short_prompt), 
Sequence("test"), Sequence(long_prompt), Sequence("hello world "), ] llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=max_new_tokens), use_tqdm=True, capture_cudagraph=True) # Debug just the sequences we care about in the right order: [long_prompt, short_prompt] comparison_sequences = [sequences[3], sequences[1]] # [long_prompt, short_prompt] results4 = debug_print_outputs(comparison_sequences, tokenizer, reference=results) del sequences for i in range(len(results)): assert results[i] == results4[i], f"{i=}, {results[i]=}, {results4[i]=}" torch.cuda.empty_cache() # Generate test batch for cudagraph test_requests = generate_benchmark_data(tokenizer, n_requests=4, max_input_length=max_input_length) sequences = [Sequence(req.text) for req in test_requests] llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=16), use_tqdm=True, capture_cudagraph=True) # Just capture cudagraph, no need to show verbose output # debug_print_outputs(sequences, tokenizer) print("after cudagraph") test_requests2 = generate_benchmark_data(tokenizer, n_requests=4, max_input_length=max_input_length) sequences = [Sequence(req.text) for req in test_requests2] + [Sequence(long_prompt), Sequence(short_prompt)] print("replay cudagraph, & profile") with get_profiler_context() as prof: llm.generate(sequences, sampling_params=SamplingParams(max_new_tokens=16), use_tqdm=True, profiler=prof) results_prefill = debug_print_outputs(sequences, tokenizer, slice=slice(-2, None), reference=results, prefix_match=True) for i in range(len(results_prefill)): assert results_prefill[i][: len(results[i])] == results[i], f"{i=}, {results[i]=}, {results_prefill[i][:len(results[i])]=}" del sequences torch.cuda.empty_cache() ## benchmark throughput n_requests = 512 max_input_length = 512 # Use shared benchmark data generation benchmark_requests = generate_benchmark_data( tokenizer, n_requests=n_requests, max_input_length=max_input_length, ) # Convert to flex-nano-vllm format 
sequences = [Sequence(req.text) for req in benchmark_requests] sampling_params = [SamplingParams(max_new_tokens=req.max_new_tokens) for req in benchmark_requests] print("\n--- RUNNING BENCHMARK ---") # Reset memory stats to track benchmark-specific usage torch.cuda.reset_peak_memory_stats() start_time = time.time() llm.generate(sequences, sampling_params=sampling_params, use_tqdm=False, save_metrics_csv=True, print_stats=True) total_time = time.time() - start_time total_output_length = sum(seq.output_length for seq in sequences) total_input_length = sum(len(seq.input_ids) for seq in sequences) # Get memory usage peak_memory_mb = torch.cuda.max_memory_allocated() / 1024 / 1024 current_memory_mb = torch.cuda.memory_allocated() / 1024 / 1024 print("\n--- PERFORMANCE METRICS ---") print(f"Total time: {total_time:.2f}s") print(f"Throughput: {total_output_length / total_time:.1f} tokens/s") print(f"Request throughput: {len(sequences) / total_time:.2f} req/s") print(f"Total throughput (prompt+new): {(total_input_length + total_output_length) / total_time:.1f} tokens/s") print(f"Peak memory: {peak_memory_mb:.1f} MB") print(f"Current memory: {current_memory_mb:.1f} MB") print("\nafter benchmark") # generate_benchmark_data appends long_prompt and short_prompt last, so check the final two sequences (robust even if fewer than n_requests prompts were sampled) results_final = debug_print_outputs(sequences, tokenizer, slice=slice(-2, None), reference=results, prefix_match=True) # Verify correctness for i in range(len(results_final)): assert results[i] == results_final[i][: len(results[i])], f"{i=}, {results[i]=}, {results_final[i][:len(results[i])]=}" ================================================ FILE: benchmark_vllm.py ================================================ # /// script # requires-python = ">=3.12" # dependencies = [ # "vllm", # "transformers", # "datasets", # "matplotlib", # "tqdm", # "pandas", # ] # /// # Usage: # Run benchmark with minimal overhead: uv run benchmark_vllm.py # Run benchmark with profiling: ENABLE_PROFILING=1 uv run benchmark_vllm.py # NOTE: this script serves 2 purposes: # 1.
it can be used to benchmark vLLM's performance, run with isolated inline dependencies. # 2. outside of __main__, it contains utils for producing the same payload for benchmarking. import hashlib import time from tqdm import tqdm from datasets import load_dataset import random from dataclasses import dataclass @dataclass class BenchmarkRequest: """Simple, framework-agnostic request data""" text: str max_new_tokens: int # Standard prompts used in benchmarks long_prompt = """ The 12 months of the year are: January, February, March, """.strip() short_prompt = "The first 20 prime numbers are: 2, 3," def _stable_hash(s: str) -> int: """Process-independent hash. The built-in hash() of a str is salted per process (unless PYTHONHASHSEED is set), so using it would sample a different workload on every run and in each of the two benchmark scripts.""" return int.from_bytes(hashlib.sha256(s.encode("utf-8")).digest()[:8], "big") def generate_benchmark_data(tokenizer, n_requests=512, max_input_length=512, min_tokens=128, max_tokens=512): """Generate benchmark data by skipping prompts that are too long.""" from datasets import load_dataset data = load_dataset("Open-Orca/OpenOrca")["train"] benchmark_requests = [] attempts = 0 while len(benchmark_requests) < n_requests: # Deterministic sampling using a stable (process-independent) hash idx = _stable_hash(f"req_{attempts}") % len(data) system_prompt = data[idx]["system_prompt"] or "" question = data[idx]["question"] or "" prompt = f"{idx}: {system_prompt} {question}".strip() # Check length and skip if too long tokens = tokenizer.encode(prompt) if len(tokens) <= max_input_length: prompt_hash = _stable_hash(prompt) max_new_tokens = min_tokens + (abs(prompt_hash >> 16) % (max_tokens - min_tokens + 1)) benchmark_requests.append(BenchmarkRequest(text=prompt, max_new_tokens=max_new_tokens)) attempts += 1 if attempts > n_requests * 10: # Safety valve to prevent infinite loop break # Add standard prompts benchmark_requests.extend([ BenchmarkRequest(text=long_prompt, max_new_tokens=max_tokens), BenchmarkRequest(text=short_prompt, max_new_tokens=max_tokens) ]) return benchmark_requests def print_step_stats(steps, name): """Helper to print timing statistics for a collection of steps.""" if not steps: return print(f"\n{name}:") print(f" Count: {len(steps)}") print(f" Total: {sum(steps):.4f}s") print(f" Mean: 
{sum(steps)/len(steps):.4f}s") print(f" Min: {min(steps):.4f}s") print(f" Max: {max(steps):.4f}s") def generate_with_timing(llm, sequences, sampling_params, collect_detailed_metrics=False): """ Generate with timing, optionally collecting detailed metrics. Note: We track total step time rather than trying to separate prefill/decode because vLLM can do both types of work within a single step, making such separation misleading for performance analysis. """ outputs = [] total_step_time = 0.0 step_times = [] # Optional detailed metrics metrics_data = {} if not collect_detailed_metrics else { 'steps': [], 'requests_running': [], 'requests_waiting': [], 'preemptions': [] } # Add requests for i, prompt in enumerate(sequences): sp = sampling_params[i] if isinstance(sampling_params, list) else sampling_params llm.llm_engine.add_request(str(i), prompt, sp) step_count = 0 while llm.llm_engine.has_unfinished_requests(): step_start = time.perf_counter() step_outputs = llm.llm_engine.step() step_duration = time.perf_counter() - step_start step_count += 1 total_step_time += step_duration step_times.append(step_duration) # Collect outputs for output in step_outputs: if output.finished: outputs.append(output) # Optional detailed metrics if collect_detailed_metrics: metrics = llm.llm_engine.get_metrics() metrics_data['steps'].append(step_count) # Extract key metrics running = waiting = preemptions = 0 for metric in metrics: if "requests_running" in metric.name: running = metric.value elif "requests_waiting" in metric.name: waiting = metric.value elif "preemptions" in metric.name: preemptions = metric.value metrics_data['requests_running'].append(running) metrics_data['requests_waiting'].append(waiting) metrics_data['preemptions'].append(preemptions) return total_step_time, step_times, outputs, metrics_data if __name__ == "__main__": import os import time os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile" os.environ['VLLM_USE_V1'] = '1' ENABLE_PROFILING = 
os.environ.get("ENABLE_PROFILING", "0") == "1" from vllm import LLM, SamplingParams as VLLMSamplingParams from transformers import AutoTokenizer MODEL_ID = "google/gemma-2-2b" llm = LLM(MODEL_ID, dtype="bfloat16", gpu_memory_utilization=0.9, max_num_seqs=256, max_model_len=2048, disable_log_stats=False) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) n_requests = 512 max_input_length = 512 benchmark_requests = generate_benchmark_data(tokenizer, n_requests, max_input_length) sequences = [req.text for req in benchmark_requests] sampling_params_list = [ VLLMSamplingParams(temperature=0.0, top_p=1.0, max_tokens=req.max_new_tokens) for req in benchmark_requests ] # Warmup llm.generate(["warmup"], VLLMSamplingParams(max_tokens=1), use_tqdm=False) print(f"\n--- RUNNING {'WITH' if ENABLE_PROFILING else 'WITHOUT'} DETAILED METRICS ---") start_time = time.time() total_step_time, step_times, outputs, metrics_data = generate_with_timing( llm, sequences, sampling_params_list, collect_detailed_metrics=ENABLE_PROFILING ) total_time = time.time() - start_time total_output_length = sum(len(o.outputs[0].token_ids) for o in outputs) prompt_tok = sum(len(o.prompt_token_ids) for o in outputs) print_step_stats(step_times, "step") print("\n--- PERFORMANCE METRICS ---") print(f"Total time: {total_time:.2f}s") print(f"Throughput: {total_output_length / total_time:.1f} tokens/s") print(f"Request throughput: {len(sequences) / total_time:.2f} req/s") print(f"Total throughput (prompt+new): {(prompt_tok + total_output_length) / total_time:.1f} tokens/s") if ENABLE_PROFILING and metrics_data: print("\n--- DETAILED METRICS ---") import pandas as pd pd.DataFrame(metrics_data).to_csv('vllm_metrics.csv', index=False) print("Metrics saved to 'vllm_metrics.csv'") ================================================ FILE: flex_nano_vllm/__init__.py ================================================ """flex-nano-vllm - 
FlexAttention-based nano-vllm implementation for fast Gemma 2 inference.""" from .modeling_gemma2 import Gemma2ForCausalLM from .inference import Inference, Sequence __version__ = "0.1.0" __all__ = ["Gemma2ForCausalLM", "Inference", "Sequence"] ================================================ FILE: flex_nano_vllm/inference.py ================================================ from collections import deque import time import torch from tqdm import tqdm from torch.nn.attention.flex_attention import BlockMask import torch.nn.attention.flex_attention import torch.nn.functional as F from rich.console import Console from flex_nano_vllm.paged_attention import PageTable, PagedKVCache from dataclasses import dataclass console = Console() print(f"torch version: {torch.__version__}") @dataclass class SamplingParams: max_new_tokens: int = -1 def sample(logits_BV, greedy=True, to_cpu=False): # NOTE: use greedy=True to ensure deterministic sampling assert logits_BV.ndim == 2 B, V = logits_BV.shape probs = torch.softmax(logits_BV, dim=-1) if not greedy: indices = torch.multinomial(probs, num_samples=1) # shape: [B, 1] logits = torch.gather(logits_BV, dim=-1, index=indices) probs = torch.gather(probs, dim=-1, index=indices) else: probs, indices = torch.topk(probs, k=1, dim=-1) logits = torch.gather(logits_BV, dim=-1, index=indices) if to_cpu: indices = indices.to("cpu", non_blocking=True).view(B) logits = logits.to("cpu", non_blocking=True).view(B) probs = probs.to("cpu", non_blocking=True).view(B) torch.cuda.synchronize() return indices.tolist(), logits.tolist(), probs.tolist() class Sequence: def __init__(self, text: str): self.done = False self.text = text self._output_ids = [] self._output_logits = [] self._output_probs = [] self.input_ids = [] self.finished = False self.input_length = None self.inputs = None def add_next_token(self, token_id: torch.Tensor, logits: torch.Tensor, probs: torch.Tensor): #assert token_id.ndim == 0 #assert logits.ndim == 0
self._output_ids.append(token_id) self._output_logits.append(logits) self._output_probs.append(probs) def copy(self): return Sequence(self.text) @property def output_ids(self): return torch.tensor(self._output_ids, dtype=torch.int64) @property def output_logits(self): return torch.tensor(self._output_logits, dtype=torch.float32) @property def output_probs(self): return torch.tensor(self._output_probs, dtype=torch.float32) @property def output_length(self): return len(self._output_ids) @property def total_length(self): return self.input_length + self.output_length @property def total_token_ids(self): if self.output_length: return torch.cat([self.input_ids, self.output_ids], dim=0) return self.input_ids @property def last_token_id(self): return self._output_ids[-1] def process_sampling_params(sequences: list[Sequence], sampling_params: SamplingParams | list[SamplingParams] | None): if sampling_params is None: sampling_params = SamplingParams() if isinstance(sampling_params, SamplingParams): sampling_params = [sampling_params] * len(sequences) assert len(sampling_params) == len(sequences), "sampling_params must be a list of the same length as sequences" for seq, param in zip(sequences, sampling_params): seq.params = param class Inference: def __init__(self, model, tokenizer, max_batch_size, max_seq_length, n_pages, page_size=128, prefill_length_limit=-1, kernel_options=None): self.page_table = PageTable(n_pages=n_pages, page_size=page_size, max_batch_size=max_batch_size) self.model = model self.tokenizer = tokenizer self.eos_token_id = tokenizer.eos_token_id # cache this because it's not efficient to call tokenizer.eos_token_id every time self.device = model.device assert max_seq_length % page_size == 0, "max_seq_length must be divisible by page_size" self.max_seq_length = max_seq_length self.max_batch_size = max_batch_size self.kernel_options = kernel_options or {} # NOTE: default to an empty dict, since prefill merges it with `|`, which fails on None self.prefill_length_limit = prefill_length_limit # NOTE: control the peak memory usage of prefill for layer
in self.model.model.layers: layer.self_attn.kv_cache = PagedKVCache( self.page_table, n_heads=self.model.model.config.num_key_value_heads, head_dim=self.model.model.config.head_dim, dtype=self.model.dtype, ).to(self.device, non_blocking=True) self.cudagraph_captured = False self.input_pos = torch.zeros(self.max_batch_size, dtype=torch.int32, pin_memory=True).to(self.device, non_blocking=True) self.block_mask = self.page_table.create_causal_blockmask(B=self.max_batch_size, L=self.max_seq_length) def _prefill_sequences( self, input_ids: torch.Tensor, input_pos: torch.Tensor, batch_idx_tensor: torch.Tensor, logits_to_keep: torch.Tensor ) -> torch.Tensor: # 1. no cuda graph # 2. construct block mask and apply it in logical space # 3. only write to kv cache, no read # NOTE: for batch/packed prefill, we need to pass batch_idx_tensor as [1, L] # input_ids is [1, L], concatenated from all sequences # batch_idx_tensor is [1, L] # position_ids is [1, L] # logits_to_keep is [num_sequences] instead of [1] ## padding: if there's padding # input_ids should be padded with any valid token id # input_pos should be padded with 0 # batch_idx_tensor should be padded with 0 # reserved in page table assert input_ids.shape[0] == 1, "input_ids must be [1, L]" assert input_pos.shape == input_ids.shape, f"input_pos must be [1, L], got {input_pos.shape=}, {input_ids.shape=}" assert batch_idx_tensor.shape == input_ids.shape, f"batch_idx_tensor must be [1, L], got {batch_idx_tensor.shape=}, {input_ids.shape=}" mask = self.page_table.create_prefill_blockmask_no_paging(batch_idx_tensor) outputs = self.model.model( input_ids=input_ids, position_ids=input_pos + 1, # NOTE: gemma2 uses 1-based position ids # logits_to_keep=logits_to_keep, flex_attn_block_mask=mask, flex_attn_input_pos=input_pos, flex_attn_batch_idx=batch_idx_tensor, flex_attn_kernel_options=self.kernel_options | {"FORCE_USE_FLEX_ATTENTION": True}, # NOTE: force torch compile to not use flash decoding code path ) return 
self.model.lm_head(outputs.last_hidden_state[:, logits_to_keep, :]) """ outputs = self.model( input_ids=input_ids, position_ids=input_pos + 1, # NOTE: gemma2 uses 1-based position ids logits_to_keep=logits_to_keep, flex_attn_block_mask=mask, flex_attn_input_pos=input_pos, flex_attn_batch_idx=batch_idx_tensor, flex_attn_kernel_options=self.kernel_options | {'FORCE_USE_FLEX_ATTENTION': True}, # NOTE: force torch compile to not use flash decoding code path ) return outputs.logits """ def prefill_sequences(self, sequences: list[Sequence]) -> torch.Tensor: input_ids = torch.cat([seq.total_token_ids for seq in sequences], dim=0) input_pos = torch.cat([torch.arange(seq.total_length, dtype=torch.long) for seq in sequences], dim=0) batch_idx_tensor = torch.cat([torch.ones(seq.total_length, dtype=torch.long) * seq.batch_idx for seq in sequences], dim=0) input_lengths = torch.tensor([seq.total_length for seq in sequences], dtype=torch.int32).to(self.device, non_blocking=True) logits_to_keep = input_lengths.cumsum(dim=0) - 1 num_pad = (-input_ids.shape[0]) % 128 # round the packed length up to a multiple of 128; 0 when already aligned if num_pad > 0: input_ids = F.pad(input_ids.view(-1), (0, num_pad), mode="constant", value=0) input_pos = F.pad(input_pos.view(-1), (0, num_pad), mode="constant", value=0) batch_idx_tensor = F.pad(batch_idx_tensor.view(-1), (0, num_pad), mode="constant", value=0) # logits_to_keep is not padded, it should have shape [num_sequences] input_ids = input_ids.view(1, -1).to(self.device, non_blocking=True) input_pos = input_pos.view(1, -1).to(self.device, non_blocking=True) batch_idx_tensor = batch_idx_tensor.view(1, -1).to(self.device, non_blocking=True) logits_to_keep = logits_to_keep.view(-1).to(self.device, non_blocking=True) logits = self._prefill_sequences(input_ids, input_pos, batch_idx_tensor, logits_to_keep) return logits def get_decoding_block_mask(self, batch_idx: torch.Tensor): """ Args: batch_idx: [B] Returns: block_mask: [B, H, ROWS=1, MAX_BLOCKS_IN_COL] input_pos: [B] This function slices the full block mask
        self.block_mask: [max_batch_size, H, MAX_BLOCKS_IN_ROW, MAX_BLOCKS_IN_COL]
        using self.input_pos: [max_batch_size] and batch_idx: [B]
        """
        # NOTE: this function is entirely in logical space

        def causal_offset(off: torch.Tensor):
            def offset(b, h, q_idx, kv_idx):
                return q_idx + off[b] >= kv_idx

            return offset

        block_mask = self.block_mask
        input_pos = self.input_pos[batch_idx]  # batch_idx: [B], input_pos: [B]
        assert batch_idx.ndim == 1, "batch_idx must be 1D"
        assert input_pos.ndim == 1, "input_pos must be 1D"
        (B,) = batch_idx.shape
        input_block_idx = input_pos // block_mask.BLOCK_SIZE[0]  # [B]
        kv_num_blocks = block_mask.kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)
        kv_indices = block_mask.kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)
        full_kv_num_blocks, full_kv_indices = None, None
        if block_mask.full_kv_num_blocks is not None:
            full_kv_num_blocks = block_mask.full_kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)  # noqa
            full_kv_indices = block_mask.full_kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)  # noqa
        seq_length = (1, block_mask.seq_lengths[1])
        mask = BlockMask.from_kv_blocks(
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            BLOCK_SIZE=block_mask.BLOCK_SIZE,
            mask_mod=causal_offset(input_pos),
            seq_lengths=seq_length,
        )
        return mask, input_pos

    def _decode_step(self, batch_idx: torch.Tensor, input_ids: torch.Tensor):
        B = input_ids.shape[0]
        mask, input_pos = self.get_decoding_block_mask(batch_idx)
        mask = self.page_table.convert_logical_block_mask(mask, batch_idx)
        outputs = self.model(
            input_ids=input_ids.view(B, 1),
            position_ids=(input_pos + 1).view(B, 1),  # NOTE: position_ids is needed for decoding.
            # For Gemma2, it's 1-based
            flex_attn_block_mask=mask,
            flex_attn_input_pos=input_pos.view(B, 1),
            flex_attn_batch_idx=batch_idx.view(-1),
            flex_attn_kernel_options=self.kernel_options,
        )
        return outputs.logits

    def decode_step(self, batch_idx: torch.Tensor, input_ids: torch.Tensor, input_pos: torch.Tensor):
        assert input_ids.ndim == 1, "input_ids must be 1D"
        assert batch_idx.ndim == 1, "batch_idx must be 1D"
        assert input_ids.shape[0] == batch_idx.shape[0], "input_ids and batch_idx must have the same length"
        self.input_pos.zero_()
        self.input_pos[batch_idx] = input_pos
        if not self.cudagraph_captured:
            return self._decode_step(batch_idx, input_ids)
        else:
            bs = input_ids.size(0)
            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
            graph_vars = self.graph_vars
            for k, v in graph_vars.items():
                if k != "outputs":
                    v.zero_()  # batch idx is 0 for undefined batch idx, so we never overwrite any kv-cache
            graph_vars["input_ids"][:bs] = input_ids
            graph_vars["batch_idx"][:bs] = batch_idx
            graph.replay()
            replayed = graph_vars["outputs"][:bs]
            return replayed

    def _check_done(self, sequences: list[Sequence]):
        for seq in sequences:
            is_eos = seq.last_token_id == self.eos_token_id
            is_max_len = seq.input_length + seq.output_length >= self.max_seq_length
            is_max_new = seq.output_length == seq.params.max_new_tokens
            if is_eos or is_max_len or is_max_new:
                seq.finished = True
                self.done.append(seq)
                self.page_table.erase(seq.batch_idx)
        return [seq for seq in sequences if not seq.finished]

    def run_one_step(self):
        # try to prefill a new sequence, if we can schedule it
        batch = []
        prefill_length_sum = 0
        # NOTE: for each sequence, only allocate & reserve just enough pages for the current sequence length
        # We do not reserve pages for the future tokens. This allows maximizing the batch size of decoding.
        # We support preemption in case we run out of pages during decoding
        while (
            self.waiting
            and self.page_table.can_reserve(self.waiting[0].total_length)  # + self.waiting[0].params.max_new_tokens)
            and (not batch or self.prefill_length_limit == -1 or prefill_length_sum + self.waiting[0].total_length < self.prefill_length_limit)
        ):
            prefill_length_sum += self.waiting[0].total_length
            seq = self.waiting.popleft()
            batch_idx_int = self.page_table.allocate()
            batch_idx_tensor = torch.tensor([batch_idx_int], device=self.device, dtype=torch.long)
            self.page_table.reserve(batch_idx_int=batch_idx_int, batch_idx=batch_idx_tensor, seq_len=seq.total_length)  # + seq.params.max_new_tokens
            seq.batch_idx = batch_idx_int
            batch.append(seq)

        if batch:
            logits_1LV = self.prefill_sequences(batch)
            next_token, logits, probs = sample(logits_1LV[0, :, :], to_cpu=True)
            for b in range(len(batch)):
                batch[b].add_next_token(next_token[b], logits[b], probs[b])
            self.running.extend(self._check_done(batch))
            return "prefill"

        # reserve new block for running sequences if needed
        # if a new block is needed, but there's no space, preempt the newest sequence
        # running: [oldest, ... , newest] waiting: [oldest, ... , newest]
        while self.running:
            seq = self.running.popleft()
            if self.page_table.capacity[seq.batch_idx] >= seq.total_length:
                # no need to reserve new pages
                batch.append(seq)
            elif self.page_table.can_reserve(seq.total_length, batch_idx_int=seq.batch_idx):
                # reserve new pages
                self.page_table.reserve(
                    batch_idx_int=seq.batch_idx,
                    batch_idx=torch.tensor([seq.batch_idx], device=self.device, dtype=torch.long),
                    seq_len=seq.total_length,
                )
                batch.append(seq)
            else:
                # no space to run this sequence, preempt the newest sequence
                self.running.appendleft(seq)  # first put this sequence back
                newest = self.running.pop()  # then pop the newest sequence
                self.waiting.appendleft(newest)
                self.page_table.erase(newest.batch_idx)

        B = len(batch)
        # now we do decoding
        batch_idx = torch.tensor([seq.batch_idx for seq in batch], dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)
        input_ids = torch.tensor([seq.last_token_id for seq in batch], dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)
        input_pos = torch.tensor([seq.total_length - 1 for seq in batch], dtype=torch.int32, pin_memory=True).to(self.device, non_blocking=True)
        self.counts.append(B)

        logits_BLV = self.decode_step(batch_idx, input_ids, input_pos)
        next_token, logits, probs = sample(logits_BLV[:, -1, :], to_cpu=True)
        for i in range(B):
            batch[i].add_next_token(next_token[i], logits[i], probs[i])
        self.running = deque(self._check_done(batch))
        return "decode"

    def tokenize(self, sequences: list[Sequence]):
        self.tokenizer.padding_side = "right"
        for seq in sequences:
            seq.input_ids = self.tokenizer([seq.text], return_tensors="pt")["input_ids"].squeeze(0)
            seq.input_length = seq.input_ids.shape[0]

    @torch.inference_mode()
    def generate(
        self,
        sequences: list[Sequence],
        use_tqdm=False,
        profiler=None,
        greedy=False,
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        capture_cudagraph=False,
        save_metrics_csv=False,
        print_stats=False,
    ):
        self.counts = []
        self.metrics_data = {
            "step": [],
            "requests_running": [],
            "requests_waiting": [],
            "step_type": [],
        }
        # preprocess the sequences
        self.tokenize(sequences)
        self.waiting = deque(sequences)
        self.running = deque()
        self.done = deque()
        process_sampling_params(sequences, sampling_params)

        if capture_cudagraph and not self.cudagraph_captured:
            self.capture_decode_cudagraph()
            self.cudagraph_captured = True

        total_sequences = len(self.waiting)
        times = []
        step_count = 0

        with tqdm(total=total_sequences, disable=not use_tqdm, desc="Generating") as pbar:
            prev_done = 0
            while self.waiting or self.running:
                step_count += 1

                # Track metrics before step
                self.metrics_data["step"].append(step_count)
                self.metrics_data["requests_running"].append(len(self.running))
                self.metrics_data["requests_waiting"].append(len(self.waiting))

                time_start = time.perf_counter()
                step_type = self.run_one_step()
                time_end = time.perf_counter()

                # Track step type
                self.metrics_data["step_type"].append(step_type)

                times.append({"step_type": step_type, "time": time_end - time_start})
                if profiler:
                    profiler.step()

                # Update progress bar based on newly completed sequences
                curr_done = len(self.done)
                if curr_done > prev_done:
                    pbar.update(curr_done - prev_done)
                    prev_done = curr_done

        if print_stats:
            self.print_time_stats(times)

        # Save metrics to CSV if requested
        if save_metrics_csv:
            import pandas as pd

            df = pd.DataFrame(self.metrics_data)
            df.to_csv("flex_nano_vllm_metrics.csv", index=False)
            print("Metrics saved to 'flex_nano_vllm_metrics.csv'")

    def capture_decode_cudagraph(self):
        """
        capture cudagraph for decoding
        """
        max_bs = self.max_batch_size
        input_ids = torch.zeros(max_bs, dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)
        batch_idx = torch.arange(max_bs, dtype=torch.int64, pin_memory=True).to(self.device, non_blocking=True)
        # NOTE: here we use logits as the final output, but we can consider using last hidden state as the output
        outputs = torch.zeros((max_bs, 1, self.model.model.config.vocab_size), pin_memory=True).to(
            self.device,
            non_blocking=True
        )

        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))  # 8 vs 16 doesn't make much difference
        self.graphs = {}
        self.graph_pool = None

        for bs in reversed(self.graph_bs):
            print(f"capturing cudagraph for {bs} sequences")
            torch.cuda.synchronize()
            graph = torch.cuda.CUDAGraph()
            # warmup
            outputs[:bs] = self._decode_step(batch_idx[:bs], input_ids[:bs])  # warmup
            # capture
            with torch.cuda.graph(graph, self.graph_pool):
                outputs[:bs] = self._decode_step(batch_idx[:bs], input_ids[:bs])  # capture
            if self.graph_pool is None:
                self.graph_pool = graph.pool()
            self.graphs[bs] = graph
            torch.cuda.synchronize()

        self.graph_vars = dict(
            input_ids=input_ids,
            batch_idx=batch_idx,
            outputs=outputs,
        )
        # in our code, the page table tensors are modified in-place, so we don't need to put them in graph vars

    def print_time_stats(self, times):
        stats = {}
        for step in ["decode", "prefill"]:
            step_times = [t["time"] for t in times if t["step_type"] == step]
            stats[step] = {
                "count": len(step_times),
                "total": sum(step_times),
                "mean": sum(step_times) / len(step_times) if step_times else 0,
                "min": min(step_times) if step_times else 0,
                "max": max(step_times) if step_times else 0,
            }

        print("\nTime statistics by step type:")
        for step, metrics in stats.items():
            print(f"\n{step}:")
            print(f"  Count: {metrics['count']}")
            print(f"  Total: {metrics['total']:.4f}s")
            print(f"  Mean: {metrics['mean']:.4f}s")
            print(f"  Min: {metrics['min']:.4f}s")
            print(f"  Max: {metrics['max']:.4f}s")

        print(f"\nTotal time: {sum(t['time'] for t in times):.4f}s")

================================================
FILE: flex_nano_vllm/modeling_gemma2.py
================================================
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular.
# If any change should be done, please apply the change to the
# modular_gemma2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union

import torch
import torch.nn as nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
from torch.nn.attention.flex_attention import flex_attention
flex_attention = torch.compile(flex_attention, fullgraph=True)

logger = logging.get_logger(__name__)


class Gemma2RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


class Gemma2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    if scaling is None:
        scaling = module.head_dim**-0.5

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    if softcap is not None:
        attn_weights = attn_weights / softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * softcap
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


class Gemma2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
        self.kv_cache = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        flex_attn_block_mask=None,
        flex_attn_input_pos=None,
        flex_attn_batch_idx=None,
        flex_attn_kernel_options={'BLOCK_M': 32, 'BLOCK_N': 32, 'BLOCK_M1': 32, 'BLOCK_M2': 32, 'BLOCK_N1': 32, 'BLOCK_N2': 32, },
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # NOTE: this does not cover the sliding window case, but in my current usage, the sequence length does not exceed 4096
        def soft_cap(score, b, h, q_idx, kv_idx):
            score = score / self.attn_logit_softcapping
            score = torch.tanh(score)
            score = score * self.attn_logit_softcapping
            return score

        if self.kv_cache is not None and flex_attn_input_pos is not None:
            key_states, value_states = self.kv_cache.update(flex_attn_input_pos, key_states, value_states, flex_attn_batch_idx)

        attn_output = flex_attention(
            query_states,
            key_states,
            value_states,
            # dropout=self.attention_dropout if self.training else 0.0,
            scale=self.scaling,
            block_mask=flex_attn_block_mask,
            score_mod=soft_cap,
            enable_gqa=True,
            kernel_options=flex_attn_kernel_options,
        )
        attn_weights = None

        attn_output = attn_output.transpose(1, 2)  # (B, H, N, E) -> (B, N, H, E)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Gemma2DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        self.attention_type = config.layer_types[layer_idx]
        self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma2MLP(config)
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @deprecate_kwarg("last_cache_position", version="4.53.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor,
    torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class Gemma2RotaryEmbedding(nn.Module):
    def __init__(self, config: Gemma2Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@auto_docstring
class Gemma2PreTrainedModel(PreTrainedModel):
    config: Gemma2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _supports_static_cache = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": Gemma2DecoderLayer,
        "attentions": Gemma2Attention,
    }

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Gemma2RMSNorm):
            module.weight.data.fill_(1.0)


@auto_docstring
class Gemma2Model(Gemma2PreTrainedModel):
    def __init__(self, config: Gemma2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # embed positions
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=None,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


@auto_docstring
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "Gemma2ForCausalLM",
    "Gemma2Model",
    "Gemma2PreTrainedModel",
]

================================================
FILE: flex_nano_vllm/paged_attention.py
================================================
# Adapted from attention-gym
# Original source:
https://github.com/pytorch-labs/attention-gym # License: BSD 3-Clause (see THIRD_PARTY_LICENSES.md) # Copyright (c) 2023, Driss Guessous # the original implementation has some bugs and has some feature that lives outside of the PageTable class from typing import Optional import torch from torch import Tensor from torch.nn.attention.flex_attention import ( _identity, _mask_mod_signature, _score_mod_signature, BlockMask, noop_mask, create_block_mask, ) create_block_mask = torch.compile(create_block_mask) def _cdiv(x: int | float | torch.Tensor, multiple: int | float | torch.Tensor): return (x + multiple - 1) // multiple class PagedKVCache(torch.nn.Module): def __init__(self, page_table, n_heads, head_dim, dtype): super().__init__() cache_shape = (1, n_heads, page_table.n_pages * page_table.page_size, head_dim) self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) self.page_table = page_table def update(self, input_pos, k_val, v_val, batch_idx=None): assert batch_idx is not None, "batch_idx is required for paged kv cache, are you using non-paged attention?" if batch_idx.ndim == 1: # batch_idx should be [B] (decode) return self.page_table.assign(batch_idx, input_pos, k_val, v_val, self.k_cache, self.v_cache) else: assert batch_idx.ndim == 2, "batch_idx must be 1D or 2D" # batch_idx should be [1, L] (batch prefill) return self.page_table.assign_prefill_no_paging(batch_idx, input_pos, k_val, v_val, self.k_cache, self.v_cache) class PageTable: """ PageTable is a modified version of PagedAttention from attention-gym. 
    PageTable improves it by:
    - maintaining a CPU copy of the page table, to avoid device-to-host transfers
    - supporting batch prefill
    - fixing a bug in the original mask_mod and score_mod by mapping the physical batch index to the logical batch index
    - subsuming free_batch_idx into the page table, so it doesn't need to be maintained separately
    """

    def __init__(
        self,
        n_pages: int,
        page_size: int,
        max_batch_size: int,
        device: str = "cuda",
    ):
        self.n_pages = n_pages
        self.page_size = page_size
        self.max_batch_size = max_batch_size
        self.device = device

        # page table: [logical_batch_idx, logical_block_idx] -> physical_page_idx
        self.page_table = -torch.ones((max_batch_size, self.n_pages), dtype=torch.int64, device=device)
        self.page_table[0, :] = 0  # page 0 is reserved for simpler code in assign_prefill_no_paging
        self.page_table_cpu = [[] for _ in range(max_batch_size)]

        self.capacity = [0 for _ in range(max_batch_size)]  # capacity: batch_idx -> number of pages allocated * page size
        self.free_pages = list(reversed(range(1, n_pages)))  # page 0 is reserved for simpler code in assign_prefill_no_paging
        self.free_batch_idx = list(reversed(range(1, max_batch_size)))  # batch_idx 0 is reserved for no-op

        # [logical_batch_idx, physical_page_idx] -> logical_page_idx
        self.physical_to_logical = -torch.ones((max_batch_size, n_pages), dtype=torch.int64, device=device)

    def can_reserve(self, size: int, batch_idx_int: int | None = None) -> bool:
        """Check whether pages can be reserved for a new or an existing request, without any GPU operations."""
        if batch_idx_int is None:
            # check if we can schedule a new request
            return self.pages_available * self.page_size >= size and len(self.free_batch_idx) > 0
        else:
            # check if we can reserve new pages for an existing request
            return self.reserve(batch_idx_int, None, size, dry_run=True)

    def allocate(self) -> int:
        """Allocate a free batch index for a new request and reset its metadata."""
        batch_idx = self.free_batch_idx.pop()

        self.capacity[batch_idx] = 0
        self.physical_to_logical[batch_idx, :] = -1
        self.page_table[batch_idx, :] = -1
        return batch_idx

    @property
    def pages_available(self) -> int:
        return len(self.free_pages)

    def reserve(self, batch_idx_int: int, batch_idx: torch.Tensor, seq_len: int, dry_run: bool = False) -> bool:
        """
        Requests the capacity of a given batch to be at least enough to
        hold `seq_len` elements.

        Args:
            batch_idx_int (int): batch index to be reserved, as a Python int (indexes the CPU-side metadata).
            batch_idx (Tensor): the same batch index as a tensor; shape :math:`(1)`. May be None when `dry_run=True`.
            seq_len (int): minimum capacity (in tokens) for the given batch.
            dry_run (bool): if True, only check whether the reservation would succeed, without updating any state.

        Returns:
            bool: True if the capacity is sufficient after the call. With `dry_run=True`, returns
            whether the reservation could be made.

        Raises:
            RuntimeError: if there are not enough free pages (in this case, no update is done).
        """
        if seq_len <= self.capacity[batch_idx_int]:
            return True

        num_pages_to_allocate = _cdiv(seq_len - self.capacity[batch_idx_int], self.page_size)

        can_allocate = num_pages_to_allocate <= self.pages_available
        if dry_run:
            return can_allocate

        if not can_allocate:
            raise RuntimeError(
                f"Cannot reserve {num_pages_to_allocate} pages for a sequence of length {seq_len} "
                f"in batch {batch_idx_int}. Only {self.pages_available} pages available. "
                f"Current capacity is {self.capacity[batch_idx_int]} tokens."
            )

        start_page_idx = self.capacity[batch_idx_int] // self.page_size
        end_page_idx = start_page_idx + num_pages_to_allocate

        # take free physical pages from the end of the free list
        allocated_pages_list = self.free_pages[-num_pages_to_allocate:]
        allocated_pages = torch.tensor(allocated_pages_list, device=self.device)
        # update page table
        self.page_table[batch_idx, start_page_idx:end_page_idx] = allocated_pages

        # update metadata
        self.physical_to_logical[batch_idx, allocated_pages] = torch.arange(
            start_page_idx,
            end_page_idx,
            device=self.device,
        )
        # update cpu side metadata
        self.page_table_cpu[batch_idx_int] += allocated_pages_list
        self.free_pages = self.free_pages[:-num_pages_to_allocate]
        self.capacity[batch_idx_int] += num_pages_to_allocate * self.page_size
        return True

    def erase(self, batch_idx: int) -> None:
        """
        Removes a single batch from paged attention and returns its pages to the free list.

        Args:
            batch_idx (int): batch index to be removed.
        """
        # NOTE: the GPU side data will only be reset/overwritten when we allocate it for a new batch
        self.free_batch_idx.append(batch_idx)
        allocated_pages_cpu = self.page_table_cpu[batch_idx]
        self.free_pages.extend(reversed(allocated_pages_cpu))
        self.page_table_cpu[batch_idx] = []

    def assign(
        self,
        batch_idx: torch.Tensor,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Writes `k_val` and `v_val` into `k_cache` and `v_cache` at the locations given by
        `batch_idx` and `input_pos`, and returns the updated caches.

        Args:
            batch_idx (Tensor): batch index; shape :math:`(B)`.
            input_pos (Tensor): input positions to be assigned for the given batch; shape :math:`(B, S)`.
            k_val, v_val (Tensor): values to be assigned; shape :math:`(B, H, S, D)`.
            k_cache, v_cache (Tensor): the caches to store the values; shape :math:`(1, H, MAX_S, D)`.
        """
        if k_val.requires_grad:
            raise RuntimeError("val must not require gradient")

        B, H, S, K_D = k_val.shape
        _, H_cache, MAX_S, D_cache = k_cache.shape
        assert H_cache == H, "number of heads must match"
        assert MAX_S >= S, "cache must have enough space"
        assert D_cache == K_D, "hidden dim must match"
        assert input_pos.shape == (B, S), "input_pos must have the same shape as val"
        assert batch_idx.shape == (B,), "batch_idx must have one dimension only"

        V_D = v_val.shape[3]
        if B != batch_idx.shape[0]:
            raise RuntimeError(f"Expected val and batch_idx to have the same batch size but got B={B} and B={batch_idx.shape[0]}.")
        if H != k_cache.shape[1]:
            raise RuntimeError(f"Expected val and cache to have the same number of heads but got H={H} and H={k_cache.shape[1]}.")
        if S != input_pos.shape[1]:
            raise RuntimeError(f"Expected val and input_pos to have the same length but got S={S} and S={input_pos.shape[1]}.")
        if K_D != k_cache.shape[3]:
            raise RuntimeError(f"Expected k_val and k_cache to have the same hidden dim but got D={K_D} and D={k_cache.shape[3]}.")
        if V_D != v_cache.shape[3]:
            raise RuntimeError(f"Expected v_val and v_cache to have the same hidden dim but got D={V_D} and D={v_cache.shape[3]}.")

        # find address
        logical_block_idx = input_pos // self.page_size  # [B, S]
        logical_block_offset = input_pos % self.page_size  # [B, S]

        # NOTE: this code path is only used for decoding.
        # For batch prefill, use assign_prefill_no_paging() instead.
        physical_block_idx = torch.gather(self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)).to(torch.int32)  # [B, S]
        addr = (physical_block_idx * self.page_size + logical_block_offset).view(-1)  # [B*S]

        k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)
        v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)

        k_cache[:, :, addr, :] = k_val
        v_cache[:, :, addr, :] = v_val
        return k_cache, v_cache

    def convert_logical_block_mask(
        self,
        block_mask: BlockMask,
        batch_idx: Optional[torch.Tensor] = None,
    ) -> BlockMask:
        """
        Converts a logical block mask by mapping its logical kv indices to the corresponding
        physical kv indices.

        Args:
            block_mask (BlockMask): logical block mask;
                kv_indices shape :math:`(B, H, ROWS, MAX_BLOCKS_IN_COL)`.
            batch_idx (Tensor): batch index corresponding to the block_mask
                batch dimension. This provides flexibility to convert a
                block mask with smaller batch size than the page table;
                shape :math:`(B)`.
        """
        B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape

        if block_mask.BLOCK_SIZE[1] != self.page_size:
            raise RuntimeError(
                f"Expected block_mask to have the same column block size as page_size but got size={block_mask.BLOCK_SIZE[1]} and size={self.page_size}"
            )

        device = block_mask.kv_num_blocks.device

        if batch_idx is None:
            batch_idx = torch.arange(B, device=device)

        assert batch_idx.ndim == 1, "batch_idx must be a 1D tensor"
        assert batch_idx.shape[0] == B, "batch_idx must have the same length as the block_mask batch dimension"
        assert B <= self.max_batch_size, "block_mask batch size must not exceed max_batch_size"

        page_table = self.page_table[batch_idx]

        def transform(num_blocks, indices):
            """
            Transform the block mask from [B, H, num_q_blocks, num_logical_kv_blocks]
            to [B, H, num_q_blocks, num_physical_kv_blocks].

            kv_num_blocks: [B, H, num_q_blocks] -> unchanged
            kv_indices: [B, H, num_q_blocks, num_logical_kv_blocks] -> [B, H, num_q_blocks, num_physical_kv_blocks]
            """
            if num_blocks is None:
                return None, None
            new_kv_num_blocks = num_blocks.clone()
            new_kv_indices = torch.zeros((B, H, ROWS, self.n_pages), dtype=torch.int32, device=device)
            new_kv_indices[:, :, :, :MAX_BLOCKS_IN_COL] = (
                torch.gather(page_table, 1, indices.view(B, -1).to(torch.int64)).view(block_mask.kv_indices.shape).to(torch.int32)
            )
            return new_kv_num_blocks, new_kv_indices

        new_kv_num_blocks, new_kv_indices = transform(block_mask.kv_num_blocks, block_mask.kv_indices)
        new_full_kv_num_blocks, new_full_kv_indices = transform(block_mask.full_kv_num_blocks, block_mask.full_kv_indices)

        new_mask_mod = self.get_mask_mod(block_mask.mask_mod, batch_idx)

        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
        return BlockMask.from_kv_blocks(
            new_kv_num_blocks,
            new_kv_indices,
            new_full_kv_num_blocks,
            new_full_kv_indices,
            block_mask.BLOCK_SIZE,
            new_mask_mod,
            seq_lengths=seq_lengths,
        )

    def get_logical_kv_idx(self, physical_batch_idx: torch.Tensor, physical_kv_idx: torch.Tensor, batch_idx: torch.Tensor):
        """
        Maps physical kv indices back to logical kv indices for the request at
        `batch_idx[physical_batch_idx]`. Returns `(is_valid, safe_logical_kv_idx)`: `is_valid`
        is False for physical pages not owned by that request, and invalid indices are
        clamped to 0 so they remain safe to use for indexing.
        """
        logical_batch_idx = batch_idx[physical_batch_idx]
        physical_kv_block = physical_kv_idx // self.page_size
        physical_kv_offset = physical_kv_idx % self.page_size
        logical_block_idx = self.physical_to_logical[logical_batch_idx, physical_kv_block]
        logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
        is_valid = logical_block_idx >= 0
        safe_logical_kv_idx = logical_kv_idx.clamp(min=0)
        return is_valid, safe_logical_kv_idx

    def get_mask_mod(self, mask_mod: Optional[_mask_mod_signature], batch_idx: torch.Tensor) -> _mask_mod_signature:
        """
        Converts a mask_mod based on mapping from the physical block index to the logical
        block index.

        Args:
            mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
            batch_idx (Tensor): page-table batch index for each batch entry of the block mask; shape :math:`(B)`.
        """
        if mask_mod is None:
            mask_mod = noop_mask

        def new_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            is_valid, safe_logical_kv_idx = self.get_logical_kv_idx(b, physical_kv_idx, batch_idx)
            return torch.where(is_valid, mask_mod(b, h, q_idx, safe_logical_kv_idx), False)

        return new_mask_mod

    # NOTE: not used in the current codebase
    def get_score_mod(self, score_mod: Optional[_score_mod_signature], batch_idx: torch.Tensor) -> _score_mod_signature:
        """
        Converts a score_mod based on mapping from the physical block index to the logical
        block index.

        Args:
            score_mod (_score_mod_signature): score_mod based on the logical block index.
            batch_idx (Tensor): page-table batch index for each batch entry of the block mask; shape :math:`(B)`.
        """
        if score_mod is None:
            score_mod = _identity

        def new_score_mod(
            score: torch.Tensor,
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ):
            is_valid, safe_logical_kv_idx = self.get_logical_kv_idx(b, physical_kv_idx, batch_idx)
            return torch.where(
                is_valid,
                score_mod(score, b, h, q_idx, safe_logical_kv_idx),
                float("-inf"),
            )

        return new_score_mod

    def create_causal_blockmask(self, B, L):
        """A minimal, unoptimized causal block mask creation function"""

        def causal(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        return create_block_mask(causal, B=B, H=None, Q_LEN=L, KV_LEN=L, BLOCK_SIZE=self.page_size, device=self.device)

    def create_prefill_blockmask_no_paging(self, batch_idx: Tensor, BLOCK_SIZE: int = 128):
        """
        Builds a document-causal block mask for a packed batch-prefill sequence.

        `batch_idx` ([1, L]) holds the batch index (document id) of each token and is not
        guaranteed to be sorted. Prefix sharing is not implemented.
        """
        assert batch_idx.ndim == 2, "batch_idx must be a 2D tensor"
        assert batch_idx.shape[0] == 1, "batch_idx must have batch size 1"
        L = batch_idx.shape[1]
        docs = batch_idx.view(-1)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        return create_block_mask(document_causal, B=1, H=None, Q_LEN=L, KV_LEN=L, BLOCK_SIZE=BLOCK_SIZE)

    # Writes prefill kv into the cache like assign(), but returns k_val, v_val instead of k_cache, v_cache.
    def assign_prefill_no_paging(
        self,
        batch_idx: torch.Tensor,
        input_pos: torch.Tensor,
        k_val: torch.Tensor,
        v_val: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Writes k_val and v_val into the cache and returns the original k_val, v_val.

        batch_idx: [1, L]
        input_pos: [1, L]
        k_val: [1, H, L, D]
        v_val: [1, H, L, D]
        k_cache: [1, H, MAX_S, D]
        v_cache: [1, H, MAX_S, D]
        """
        assert batch_idx.ndim == 2, "batch_idx must be a 2D tensor"
        assert input_pos.ndim == 2, "input_pos must be a 2D tensor"
        assert k_val.ndim == 4, "k_val must be a 4D tensor"
        assert v_val.ndim == 4, "v_val must be a 4D tensor"
        assert k_cache.ndim == 4, "k_cache must be a 4D tensor"
        assert v_cache.ndim == 4, "v_cache must be a 4D tensor"
        assert batch_idx.shape[0] == 1, "batch_idx must have batch size 1"

        input_pos_block_idx = input_pos // self.page_size
        input_pos_offset_in_block = input_pos % self.page_size
        physical_kv_idx = self.page_table[batch_idx, input_pos_block_idx] * self.page_size + input_pos_offset_in_block
        k_cache[:, :, physical_kv_idx.view(-1), :] = k_val
        v_cache[:, :, physical_kv_idx.view(-1), :] = v_val
        return k_val, v_val

================================================
FILE: plot_metrics.py
================================================
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "pandas",
#     "matplotlib",
# ]
# ///
import pandas as pd
import matplotlib.pyplot as plt

# Read the CSV files
flex_nano_df = pd.read_csv('flex_nano_vllm_metrics.csv')
vllm_df = pd.read_csv('vllm_metrics.csv')

# Create figure with subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# Plot 1: Running requests comparison
ax1.plot(flex_nano_df['step'], flex_nano_df['requests_running'], label='Flex Nano VLLM', color='blue', linewidth=1.5)
ax1.plot(vllm_df['steps'], vllm_df['requests_running'], label='VLLM', color='red', linewidth=1.5)
ax1.set_title('Running Requests Over Time')
ax1.set_xlabel('Step')
ax1.set_ylabel('Running Requests')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Waiting requests comparison
ax2.plot(flex_nano_df['step'], flex_nano_df['requests_waiting'], label='Flex Nano VLLM', color='blue', linewidth=1.5)
ax2.plot(vllm_df['steps'], vllm_df['requests_waiting'], label='VLLM', color='red', linewidth=1.5)
ax2.set_title('Waiting Requests Over Time')
ax2.set_xlabel('Step')
ax2.set_ylabel('Waiting Requests')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Flex Nano VLLM step types
prefill_steps = flex_nano_df[flex_nano_df['step_type'] == 'prefill']
decode_steps = flex_nano_df[flex_nano_df['step_type'] == 'decode']
ax3.scatter(prefill_steps['step'], prefill_steps['requests_running'], label='Prefill', alpha=0.6, s=10, color='green')
ax3.scatter(decode_steps['step'], decode_steps['requests_running'], label='Decode', alpha=0.6, s=10, color='orange')
ax3.set_title('Flex Nano VLLM: Running Requests by Step Type')
ax3.set_xlabel('Step')
ax3.set_ylabel('Running Requests')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: Total requests (running + waiting)
flex_nano_total = flex_nano_df['requests_running'] + flex_nano_df['requests_waiting']
vllm_total = vllm_df['requests_running'] + vllm_df['requests_waiting']
ax4.plot(flex_nano_df['step'], flex_nano_total, label='Flex Nano VLLM Total', color='blue', linewidth=1.5)
ax4.plot(vllm_df['steps'], vllm_total, label='VLLM Total', color='red', linewidth=1.5)
ax4.set_title('Total Requests (Running + Waiting)')
ax4.set_xlabel('Step')
ax4.set_ylabel('Total Requests')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('metrics_comparison.png', dpi=300, bbox_inches='tight')
print("Metrics comparison saved as 'metrics_comparison.png'")
plt.show()

# Print some summary statistics
print("\n=== Summary Statistics ===")
print(f"Flex Nano VLLM - Max running: {flex_nano_df['requests_running'].max()}")
print(f"VLLM - Max running: {vllm_df['requests_running'].max()}")
print(f"Flex Nano VLLM - Max waiting: {flex_nano_df['requests_waiting'].max()}")
print(f"VLLM - Max waiting: {vllm_df['requests_waiting'].max()}")
print(f"Flex Nano VLLM - Total steps: {len(flex_nano_df)}")
print(f"VLLM - Total steps: {len(vllm_df)}")

================================================
FILE: pyproject.toml
================================================
[project]
name = "flex-nano-vllm"
version = "0.1.0"
description = "Flex-attention based nano-vllm implementation for fast PaliGemma inference"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "accelerate>=1.9.0",
    "datasets>=3.0.0",
    "hf-transfer>=0.1.9",
    "matplotlib>=3.10.3",
    "torch>=2.7.1",
    "tqdm>=4.67.1",
    "transformers>=4.53.2",
    "triton>=3.3.1",
]

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
packages = ["flex_nano_vllm"]

[dependency-groups]
dev = [
    "rich>=14.1.0",
]

[tool.uv.sources]
transformers = { git = "https://github.com/huggingface/transformers", rev = "34133d0a" }

================================================
FILE: visualize.py
================================================
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "matplotlib",
#     "numpy",
# ]
# ///
import matplotlib.pyplot as plt
import numpy as np

# Data
configs = ['50% GPU', '90% GPU', '90% GPU\n(high batch)']
vllm_output = [3020, 3772, 3840]
flex_output = [2313, 3076, 3440]

# Create figure
fig, ax = plt.subplots(figsize=(12, 8))

x = np.arange(len(configs))
width = 0.35

bars1 = ax.bar(x - width/2, vllm_output, width, label='vLLM v1', color='#1f77b4', alpha=0.8)
bars2 = ax.bar(x + width/2, flex_output, width, label='flex-nano-vllm', color='#ff7f0e', alpha=0.8)

ax.set_title('Output Tokens/s Comparison by Configuration', fontsize=16, fontweight='bold', pad=20)
ax.set_ylabel('Tokens/s', fontsize=14)
ax.set_xlabel('GPU Memory Configuration', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(configs, fontsize=12)
ax.legend(fontsize=12)
ax.grid(axis='y', alpha=0.3)

# Add value labels
for bar in bars1:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 50,
            f'{int(height)}', ha='center', va='bottom', fontweight='bold', fontsize=11)

# Add value labels with percentages for flex-nano-vllm
for i, bar in enumerate(bars2):
    height = bar.get_height()
    percentage = (flex_output[i] / vllm_output[i]) * 100
    ax.text(bar.get_x() + bar.get_width()/2., height + 50,
            f'{int(height)}\n({percentage:.1f}%)', ha='center', va='bottom', fontweight='bold', fontsize=11)

plt.tight_layout()

# Save the plot
plt.savefig('tokens_per_second_comparison.png', dpi=300, bbox_inches='tight')
print("Simple comparison saved as 'tokens_per_second_comparison.png'")
plt.show()