[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in 
version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n#poetry.toml\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  
For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/\n\n# Abstra\n# Abstra is an AI-powered process automation framework.\n# Ignore directories containing user credentials, local state, and settings.\n# Learn more at https://abstra.io/docs\n.abstra/\n\n# Visual Studio Code\n#  Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore \n#  that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore\n#  and can be added to the global gitignore or merged into this file. However, if you prefer, \n#  you could uncomment the following to ignore the entire vscode folder\n.vscode/\n\n# Ruff stuff:\n.ruff_cache/\n\n# PyPI configuration file\n.pypirc\n\n# Cursor\n#  Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to\n#  exclude from AI features like autocomplete and code analysis. Recommended for sensitive data\n#  refer to https://docs.cursor.com/context/ignore-files\n.cursorignore\n.cursorindexingignore"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2025 Xingkai Yu\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n<img width=\"300\" src=\"assets/logo.png\">\n</p>\n\n<p align=\"center\">\n<a href=\"https://trendshift.io/repositories/15323\" target=\"_blank\"><img src=\"https://trendshift.io/api/badge/repositories/15323\" alt=\"GeeeekExplorer%2Fnano-vllm | Trendshift\" style=\"width: 250px; height: 55px;\" width=\"250\" height=\"55\"/></a>\n</p>\n\n# Nano-vLLM\n\nA lightweight vLLM implementation built from scratch.\n\n## Key Features\n\n* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM\n* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code\n* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.\n\n## Installation\n\n```bash\npip install git+https://github.com/GeeeekExplorer/nano-vllm.git\n```\n\n## Model Download\n\nTo download the model weights manually, use the following command:\n```bash\nhuggingface-cli download --resume-download Qwen/Qwen3-0.6B \\\n  --local-dir ~/huggingface/Qwen3-0.6B/ \\\n  --local-dir-use-symlinks False\n```\n\n## Quick Start\n\nSee `example.py` for usage. 
The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:\n```python\nfrom nanovllm import LLM, SamplingParams\nllm = LLM(\"/YOUR/MODEL/PATH\", enforce_eager=True, tensor_parallel_size=1)\nsampling_params = SamplingParams(temperature=0.6, max_tokens=256)\nprompts = [\"Hello, Nano-vLLM.\"]\noutputs = llm.generate(prompts, sampling_params)\noutputs[0][\"text\"]\n```\n\n## Benchmark\n\nSee `bench.py` for the benchmark script.\n\n**Test Configuration:**\n- Hardware: RTX 4070 Laptop (8GB)\n- Model: Qwen3-0.6B\n- Total Requests: 256 sequences\n- Input Length: Randomly sampled between 100–1024 tokens\n- Output Length: Randomly sampled between 100–1024 tokens\n\n**Performance Results:**\n| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |\n|----------------|-------------|----------|-----------------------|\n| vLLM           | 133,966     | 98.37    | 1361.84               |\n| Nano-vLLM      | 133,966     | 93.41    | 1434.13               |\n\n\n## Star History\n\n[![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)"
  },
  {
    "path": "bench.py",
    "content": "import os\nimport time\nfrom random import randint, seed\nfrom nanovllm import LLM, SamplingParams\n# from vllm import LLM, SamplingParams\n\n\ndef main():\n    seed(0)\n    num_seqs = 256\n    max_input_len = 1024\n    max_ouput_len = 1024\n\n    path = os.path.expanduser(\"~/huggingface/Qwen3-0.6B/\")\n    llm = LLM(path, enforce_eager=False, max_model_len=4096)\n\n    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]\n    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]\n    # uncomment the following line for vllm\n    # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]\n\n    llm.generate([\"Benchmark: \"], SamplingParams())\n    t = time.time()\n    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)\n    t = (time.time() - t)\n    total_tokens = sum(sp.max_tokens for sp in sampling_params)\n    throughput = total_tokens / t\n    print(f\"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "example.py",
    "content": "import os\nfrom nanovllm import LLM, SamplingParams\nfrom transformers import AutoTokenizer\n\n\ndef main():\n    path = os.path.expanduser(\"~/huggingface/Qwen3-0.6B/\")\n    tokenizer = AutoTokenizer.from_pretrained(path)\n    llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)\n\n    sampling_params = SamplingParams(temperature=0.6, max_tokens=256)\n    prompts = [\n        \"introduce yourself\",\n        \"list all prime numbers within 100\",\n    ]\n    prompts = [\n        tokenizer.apply_chat_template(\n            [{\"role\": \"user\", \"content\": prompt}],\n            tokenize=False,\n            add_generation_prompt=True,\n        )\n        for prompt in prompts\n    ]\n    outputs = llm.generate(prompts, sampling_params)\n\n    for prompt, output in zip(prompts, outputs):\n        print(\"\\n\")\n        print(f\"Prompt: {prompt!r}\")\n        print(f\"Completion: {output['text']!r}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "nanovllm/__init__.py",
    "content": "from nanovllm.llm import LLM\nfrom nanovllm.sampling_params import SamplingParams\n"
  },
  {
    "path": "nanovllm/config.py",
    "content": "import os\nfrom dataclasses import dataclass\nfrom transformers import AutoConfig\n\n\n@dataclass\nclass Config:\n    model: str\n    max_num_batched_tokens: int = 16384\n    max_num_seqs: int = 512\n    max_model_len: int = 4096\n    gpu_memory_utilization: float = 0.9\n    tensor_parallel_size: int = 1\n    enforce_eager: bool = False\n    hf_config: AutoConfig | None = None\n    eos: int = -1\n    kvcache_block_size: int = 256\n    num_kvcache_blocks: int = -1\n\n    def __post_init__(self):\n        assert os.path.isdir(self.model)\n        assert self.kvcache_block_size % 256 == 0\n        assert 1 <= self.tensor_parallel_size <= 8\n        self.hf_config = AutoConfig.from_pretrained(self.model)\n        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)\n        assert self.max_num_batched_tokens >= self.max_model_len\n"
  },
  {
    "path": "nanovllm/engine/block_manager.py",
    "content": "from collections import deque\nimport xxhash\nimport numpy as np\n\nfrom nanovllm.engine.sequence import Sequence\n\n\nclass Block:\n\n    def __init__(self, block_id):\n        self.block_id = block_id\n        self.ref_count = 0\n        self.hash = -1\n        self.token_ids = []\n\n    def update(self, hash: int, token_ids: list[int]):\n        self.hash = hash\n        self.token_ids = token_ids\n\n    def reset(self):\n        self.ref_count = 1\n        self.hash = -1\n        self.token_ids = []\n\n\nclass BlockManager:\n\n    def __init__(self, num_blocks: int, block_size: int):\n        self.block_size = block_size\n        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]\n        self.hash_to_block_id: dict[int, int] = dict()\n        self.free_block_ids: deque[int] = deque(range(num_blocks))\n        self.used_block_ids: set[int] = set()\n\n    @classmethod\n    def compute_hash(cls, token_ids: list[int], prefix: int = -1):\n        h = xxhash.xxh64()\n        if prefix != -1:\n            h.update(prefix.to_bytes(8, \"little\"))\n        h.update(np.array(token_ids).tobytes())\n        return h.intdigest()\n\n    def _allocate_block(self, block_id: int) -> Block:\n        block = self.blocks[block_id]\n        assert block.ref_count == 0\n        block.reset()\n        self.free_block_ids.remove(block_id)\n        self.used_block_ids.add(block_id)\n        return self.blocks[block_id]\n\n    def _deallocate_block(self, block_id: int) -> Block:\n        assert self.blocks[block_id].ref_count == 0\n        self.used_block_ids.remove(block_id)\n        self.free_block_ids.append(block_id)\n\n    def can_allocate(self, seq: Sequence) -> bool:\n        return len(self.free_block_ids) >= seq.num_blocks\n\n    def allocate(self, seq: Sequence):\n        assert not seq.block_table\n        h = -1\n        cache_miss = False\n        for i in range(seq.num_blocks):\n            token_ids = seq.block(i)\n            h = 
self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1\n            block_id = self.hash_to_block_id.get(h, -1)\n            if block_id == -1 or self.blocks[block_id].token_ids != token_ids:\n                cache_miss = True\n            if cache_miss:\n                block_id = self.free_block_ids[0]\n                block = self._allocate_block(block_id)\n            else:\n                seq.num_cached_tokens += self.block_size\n                if block_id in self.used_block_ids:\n                    block = self.blocks[block_id]\n                    block.ref_count += 1\n                else:\n                    block = self._allocate_block(block_id)\n            if h != -1:\n                block.update(h, token_ids)\n                self.hash_to_block_id[h] = block_id\n            seq.block_table.append(block_id)\n\n    def deallocate(self, seq: Sequence):\n        for block_id in reversed(seq.block_table):\n            block = self.blocks[block_id]\n            block.ref_count -= 1\n            if block.ref_count == 0:\n                self._deallocate_block(block_id)\n        seq.num_cached_tokens = 0\n        seq.block_table.clear()\n\n    def can_append(self, seq: Sequence) -> bool:\n        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)\n\n    def may_append(self, seq: Sequence):\n        block_table = seq.block_table\n        last_block = self.blocks[block_table[-1]]\n        if len(seq) % self.block_size == 1:\n            assert last_block.hash != -1\n            block_id = self.free_block_ids[0]\n            self._allocate_block(block_id)\n            block_table.append(block_id)\n        elif len(seq) % self.block_size == 0:\n            assert last_block.hash == -1\n            token_ids = seq.block(seq.num_blocks-1)\n            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1\n            h = self.compute_hash(token_ids, prefix)\n            last_block.update(h, 
token_ids)\n            self.hash_to_block_id[h] = last_block.block_id\n        else:\n            assert last_block.hash == -1\n"
  },
  {
    "path": "nanovllm/engine/llm_engine.py",
    "content": "import atexit\nfrom dataclasses import fields\nfrom time import perf_counter\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer\nimport torch.multiprocessing as mp\n\nfrom nanovllm.config import Config\nfrom nanovllm.sampling_params import SamplingParams\nfrom nanovllm.engine.sequence import Sequence\nfrom nanovllm.engine.scheduler import Scheduler\nfrom nanovllm.engine.model_runner import ModelRunner\n\n\nclass LLMEngine:\n\n    def __init__(self, model, **kwargs):\n        config_fields = {field.name for field in fields(Config)}\n        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}\n        config = Config(model, **config_kwargs)\n        self.ps = []\n        self.events = []\n        ctx = mp.get_context(\"spawn\")\n        for i in range(1, config.tensor_parallel_size):\n            event = ctx.Event()\n            process = ctx.Process(target=ModelRunner, args=(config, i, event))\n            process.start()\n            self.ps.append(process)\n            self.events.append(event)\n        self.model_runner = ModelRunner(config, 0, self.events)\n        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)\n        config.eos = self.tokenizer.eos_token_id\n        self.scheduler = Scheduler(config)\n        atexit.register(self.exit)\n\n    def exit(self):\n        self.model_runner.call(\"exit\")\n        del self.model_runner\n        for p in self.ps:\n            p.join()\n\n    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):\n        if isinstance(prompt, str):\n            prompt = self.tokenizer.encode(prompt)\n        seq = Sequence(prompt, sampling_params)\n        self.scheduler.add(seq)\n\n    def step(self):\n        seqs, is_prefill = self.scheduler.schedule()\n        token_ids = self.model_runner.call(\"run\", seqs, is_prefill)\n        self.scheduler.postprocess(seqs, token_ids)\n        outputs = [(seq.seq_id, 
seq.completion_token_ids) for seq in seqs if seq.is_finished]\n        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)\n        return outputs, num_tokens\n\n    def is_finished(self):\n        return self.scheduler.is_finished()\n\n    def generate(\n        self,\n        prompts: list[str] | list[list[int]],\n        sampling_params: SamplingParams | list[SamplingParams],\n        use_tqdm: bool = True,\n    ) -> list[str]:\n        if use_tqdm:\n            pbar = tqdm(total=len(prompts), desc=\"Generating\", dynamic_ncols=True)\n        if not isinstance(sampling_params, list):\n            sampling_params = [sampling_params] * len(prompts)\n        for prompt, sp in zip(prompts, sampling_params):\n            self.add_request(prompt, sp)\n        outputs = {}\n        prefill_throughput = decode_throughput = 0.\n        while not self.is_finished():\n            t = perf_counter()\n            output, num_tokens = self.step()\n            if use_tqdm:\n                if num_tokens > 0:\n                    prefill_throughput = num_tokens / (perf_counter() - t)\n                else:\n                    decode_throughput = -num_tokens / (perf_counter() - t)\n                pbar.set_postfix({\n                    \"Prefill\": f\"{int(prefill_throughput)}tok/s\",\n                    \"Decode\": f\"{int(decode_throughput)}tok/s\",\n                })\n            for seq_id, token_ids in output:\n                outputs[seq_id] = token_ids\n                if use_tqdm:\n                    pbar.update(1)\n        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]\n        outputs = [{\"text\": self.tokenizer.decode(token_ids), \"token_ids\": token_ids} for token_ids in outputs]\n        if use_tqdm:\n            pbar.close()\n        return outputs\n"
  },
  {
    "path": "nanovllm/engine/model_runner.py",
    "content": "import pickle\nimport torch\nimport torch.distributed as dist\nfrom multiprocessing.synchronize import Event\nfrom multiprocessing.shared_memory import SharedMemory\n\nfrom nanovllm.config import Config\nfrom nanovllm.engine.sequence import Sequence\nfrom nanovllm.models.qwen3 import Qwen3ForCausalLM\nfrom nanovllm.layers.sampler import Sampler\nfrom nanovllm.utils.context import set_context, get_context, reset_context\nfrom nanovllm.utils.loader import load_model\n\n\nclass ModelRunner:\n\n    def __init__(self, config: Config, rank: int, event: Event | list[Event]):\n        self.config = config\n        hf_config = config.hf_config\n        self.block_size = config.kvcache_block_size\n        self.enforce_eager = config.enforce_eager\n        self.world_size = config.tensor_parallel_size\n        self.rank = rank\n        self.event = event\n\n        dist.init_process_group(\"nccl\", \"tcp://localhost:2333\", world_size=self.world_size, rank=rank)\n        torch.cuda.set_device(rank)\n        default_dtype = torch.get_default_dtype()\n        torch.set_default_dtype(hf_config.torch_dtype)\n        torch.set_default_device(\"cuda\")\n        self.model = Qwen3ForCausalLM(hf_config)\n        load_model(self.model, config.model)\n        self.sampler = Sampler()\n        self.warmup_model()\n        self.allocate_kv_cache()\n        if not self.enforce_eager:\n            self.capture_cudagraph()\n        torch.set_default_device(\"cpu\")\n        torch.set_default_dtype(default_dtype)\n\n        if self.world_size > 1:\n            if rank == 0:\n                self.shm = SharedMemory(name=\"nanovllm\", create=True, size=2**20)\n                dist.barrier()\n            else:\n                dist.barrier()\n                self.shm = SharedMemory(name=\"nanovllm\")\n                self.loop()\n\n    def exit(self):\n        if self.world_size > 1:\n            self.shm.close()\n            dist.barrier()\n            if self.rank == 0:\n      
          self.shm.unlink()\n        if not self.enforce_eager:\n            del self.graphs, self.graph_pool\n        torch.cuda.synchronize()\n        dist.destroy_process_group()\n\n    def loop(self):\n        while True:\n            method_name, args = self.read_shm()\n            self.call(method_name, *args)\n            if method_name == \"exit\":\n                break\n\n    def read_shm(self):\n        assert self.world_size > 1 and self.rank > 0\n        self.event.wait()\n        n = int.from_bytes(self.shm.buf[0:4], \"little\")\n        method_name, *args = pickle.loads(self.shm.buf[4:n+4])\n        self.event.clear()\n        return method_name, args\n\n    def write_shm(self, method_name, *args):\n        assert self.world_size > 1 and self.rank == 0\n        data = pickle.dumps([method_name, *args])\n        n = len(data)\n        self.shm.buf[0:4] = n.to_bytes(4, \"little\")\n        self.shm.buf[4:n+4] = data\n        for event in self.event:\n            event.set()\n\n    def call(self, method_name, *args):\n        if self.world_size > 1 and self.rank == 0:\n            self.write_shm(method_name, *args)\n        method = getattr(self, method_name, None)\n        return method(*args)\n\n    def warmup_model(self):\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len\n        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)\n        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]\n        self.run(seqs, True)\n        torch.cuda.empty_cache()\n\n    def allocate_kv_cache(self):\n        config = self.config\n        hf_config = config.hf_config\n        free, total = torch.cuda.mem_get_info()\n        used = total - free\n        peak = torch.cuda.memory_stats()[\"allocated_bytes.all.peak\"]\n        current = 
torch.cuda.memory_stats()[\"allocated_bytes.all.current\"]\n        num_kv_heads = hf_config.num_key_value_heads // self.world_size\n        head_dim = getattr(hf_config, \"head_dim\", hf_config.hidden_size // hf_config.num_attention_heads)\n        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize\n        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes\n        assert config.num_kvcache_blocks > 0\n        self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)\n        layer_id = 0\n        for module in self.model.modules():\n            if hasattr(module, \"k_cache\") and hasattr(module, \"v_cache\"):\n                module.k_cache = self.kv_cache[0, layer_id]\n                module.v_cache = self.kv_cache[1, layer_id]\n                layer_id += 1\n\n    def prepare_block_tables(self, seqs: list[Sequence]):\n        max_len = max(len(seq.block_table) for seq in seqs)\n        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]\n        block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)\n        return block_tables\n\n    def prepare_prefill(self, seqs: list[Sequence]):\n        input_ids = []\n        positions = []\n        cu_seqlens_q = [0]\n        cu_seqlens_k = [0]\n        max_seqlen_q = 0\n        max_seqlen_k = 0\n        slot_mapping = []\n        block_tables = None\n        for seq in seqs:\n            seqlen = len(seq)\n            input_ids.extend(seq[seq.num_cached_tokens:])\n            positions.extend(list(range(seq.num_cached_tokens, seqlen)))\n            seqlen_q = seqlen - seq.num_cached_tokens\n            seqlen_k = seqlen\n            cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)\n            cu_seqlens_k.append(cu_seqlens_k[-1] + 
seqlen_k)\n            max_seqlen_q = max(seqlen_q, max_seqlen_q)\n            max_seqlen_k = max(seqlen_k, max_seqlen_k)\n            if not seq.block_table:    # warmup\n                continue\n            for i in range(seq.num_cached_blocks, seq.num_blocks):\n                start = seq.block_table[i] * self.block_size\n                if i != seq.num_blocks - 1:\n                    end = start + self.block_size\n                else:\n                    end = start + seq.last_block_num_tokens \n                slot_mapping.extend(list(range(start, end)))\n        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache\n            block_tables = self.prepare_block_tables(seqs)\n        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)\n        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)\n        cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)\n        cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)\n        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)\n        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)\n        return input_ids, positions\n\n    def prepare_decode(self, seqs: list[Sequence]):\n        input_ids = []\n        positions = []\n        slot_mapping = []\n        context_lens = []\n        for seq in seqs:\n            input_ids.append(seq.last_token)\n            positions.append(len(seq) - 1)\n            context_lens.append(len(seq))\n            slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens  - 1)\n        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)\n        positions = torch.tensor(positions, dtype=torch.int64, 
pin_memory=True).cuda(non_blocking=True)\n        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)\n        context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)\n        block_tables = self.prepare_block_tables(seqs)\n        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)\n        return input_ids, positions\n\n    def prepare_sample(self, seqs: list[Sequence]):\n        temperatures = []\n        for seq in seqs:\n            temperatures.append(seq.temperature)\n        temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)\n        return temperatures\n\n    @torch.inference_mode()\n    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):\n        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:\n            return self.model.compute_logits(self.model(input_ids, positions))\n        else:\n            bs = input_ids.size(0)\n            context = get_context()\n            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]\n            graph_vars = self.graph_vars\n            graph_vars[\"input_ids\"][:bs] = input_ids\n            graph_vars[\"positions\"][:bs] = positions\n            graph_vars[\"slot_mapping\"].fill_(-1)\n            graph_vars[\"slot_mapping\"][:bs] = context.slot_mapping\n            graph_vars[\"context_lens\"].zero_()\n            graph_vars[\"context_lens\"][:bs] = context.context_lens\n            graph_vars[\"block_tables\"][:bs, :context.block_tables.size(1)] = context.block_tables\n            graph.replay()\n            return self.model.compute_logits(graph_vars[\"outputs\"][:bs])\n\n    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:\n        input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)\n        
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None\n        logits = self.run_model(input_ids, positions, is_prefill)\n        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None\n        reset_context()\n        return token_ids\n\n    @torch.inference_mode()\n    def capture_cudagraph(self):\n        config = self.config\n        hf_config = config.hf_config\n        max_bs = min(self.config.max_num_seqs, 512)\n        max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size\n        input_ids = torch.zeros(max_bs, dtype=torch.int64)\n        positions = torch.zeros(max_bs, dtype=torch.int64)\n        slot_mapping = torch.zeros(max_bs, dtype=torch.int32)\n        context_lens = torch.zeros(max_bs, dtype=torch.int32)\n        block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)\n        outputs = torch.zeros(max_bs, hf_config.hidden_size)\n        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))\n        self.graphs = {}\n        self.graph_pool = None\n\n        for bs in reversed(self.graph_bs):\n            graph = torch.cuda.CUDAGraph()\n            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])\n            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup\n            with torch.cuda.graph(graph, self.graph_pool):\n                outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture\n            if self.graph_pool is None:\n                self.graph_pool = graph.pool()\n            self.graphs[bs] = graph\n            torch.cuda.synchronize()\n            reset_context()\n\n        self.graph_vars = dict(\n            input_ids=input_ids,\n            positions=positions,\n            slot_mapping=slot_mapping,\n            context_lens=context_lens,\n            block_tables=block_tables,\n            outputs=outputs,\n        )\n"
  },
  {
    "path": "nanovllm/engine/scheduler.py",
    "content": "from collections import deque\n\nfrom nanovllm.config import Config\nfrom nanovllm.engine.sequence import Sequence, SequenceStatus\nfrom nanovllm.engine.block_manager import BlockManager\n\n\nclass Scheduler:\n\n    def __init__(self, config: Config):\n        self.max_num_seqs = config.max_num_seqs\n        self.max_num_batched_tokens = config.max_num_batched_tokens\n        self.eos = config.eos\n        self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)\n        self.waiting: deque[Sequence] = deque()\n        self.running: deque[Sequence] = deque()\n\n    def is_finished(self):\n        return not self.waiting and not self.running\n\n    def add(self, seq: Sequence):\n        self.waiting.append(seq)\n\n    def schedule(self) -> tuple[list[Sequence], bool]:\n        # prefill\n        scheduled_seqs = []\n        num_seqs = 0\n        num_batched_tokens = 0\n        while self.waiting and num_seqs < self.max_num_seqs:\n            seq = self.waiting[0]\n            if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):\n                break\n            num_seqs += 1\n            self.block_manager.allocate(seq)\n            num_batched_tokens += len(seq) - seq.num_cached_tokens\n            seq.status = SequenceStatus.RUNNING\n            self.waiting.popleft()\n            self.running.append(seq)\n            scheduled_seqs.append(seq)\n        if scheduled_seqs:\n            return scheduled_seqs, True\n\n        # decode\n        while self.running and num_seqs < self.max_num_seqs:\n            seq = self.running.popleft()\n            while not self.block_manager.can_append(seq):\n                if self.running:\n                    self.preempt(self.running.pop())\n                else:\n                    self.preempt(seq)\n                    break\n            else:\n                num_seqs += 1\n                
self.block_manager.may_append(seq)\n                scheduled_seqs.append(seq)\n        assert scheduled_seqs\n        self.running.extendleft(reversed(scheduled_seqs))\n        return scheduled_seqs, False\n\n    def preempt(self, seq: Sequence):\n        seq.status = SequenceStatus.WAITING\n        self.block_manager.deallocate(seq)\n        self.waiting.appendleft(seq)\n\n    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:\n        for seq, token_id in zip(seqs, token_ids):\n            seq.append_token(token_id)\n            if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:\n                seq.status = SequenceStatus.FINISHED\n                self.block_manager.deallocate(seq)\n                self.running.remove(seq)\n"
  },
  {
    "path": "nanovllm/engine/sequence.py",
    "content": "from copy import copy\nfrom enum import Enum, auto\nfrom itertools import count\n\nfrom nanovllm.sampling_params import SamplingParams\n\n\nclass SequenceStatus(Enum):\n    WAITING = auto()\n    RUNNING = auto()\n    FINISHED = auto()\n\n\nclass Sequence:\n    block_size = 256\n    counter = count()\n\n    def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):\n        self.seq_id = next(Sequence.counter)\n        self.status = SequenceStatus.WAITING\n        self.token_ids = copy(token_ids)\n        self.last_token = token_ids[-1]\n        self.num_tokens = len(self.token_ids)\n        self.num_prompt_tokens = len(token_ids)\n        self.num_cached_tokens = 0\n        self.block_table = []\n        self.temperature = sampling_params.temperature\n        self.max_tokens = sampling_params.max_tokens\n        self.ignore_eos = sampling_params.ignore_eos\n\n    def __len__(self):\n        return self.num_tokens\n\n    def __getitem__(self, key):\n        return self.token_ids[key]\n\n    @property\n    def is_finished(self):\n        return self.status == SequenceStatus.FINISHED\n\n    @property\n    def num_completion_tokens(self):\n        return self.num_tokens - self.num_prompt_tokens\n\n    @property\n    def prompt_token_ids(self):\n        return self.token_ids[:self.num_prompt_tokens]\n\n    @property\n    def completion_token_ids(self):\n        return self.token_ids[self.num_prompt_tokens:]\n\n    @property\n    def num_cached_blocks(self):\n        return self.num_cached_tokens // self.block_size\n\n    @property\n    def num_blocks(self):\n        return (self.num_tokens + self.block_size - 1) // self.block_size\n\n    @property\n    def last_block_num_tokens(self):\n        return self.num_tokens - (self.num_blocks - 1) * self.block_size\n\n    def block(self, i):\n        assert 0 <= i < self.num_blocks\n        return self.token_ids[i*self.block_size: (i+1)*self.block_size]\n\n    def append_token(self, token_id: 
int):\n        self.token_ids.append(token_id)\n        self.last_token = token_id\n        self.num_tokens += 1\n\n    def __getstate__(self):\n        return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,\n                self.token_ids if self.num_completion_tokens == 0 else self.last_token)\n\n    def __setstate__(self, state):\n        self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]\n        if self.num_completion_tokens == 0:\n            self.token_ids = state[-1]\n        else:\n            self.last_token = state[-1]\n"
  },
  {
    "path": "nanovllm/layers/activation.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\n\nclass SiluAndMul(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n\n    @torch.compile\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x, y = x.chunk(2, -1)\n        return F.silu(x) * y\n"
  },
  {
    "path": "nanovllm/layers/attention.py",
    "content": "import torch\nfrom torch import nn\nimport triton\nimport triton.language as tl\n\nfrom flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache\nfrom nanovllm.utils.context import get_context\n\n\n@triton.jit\ndef store_kvcache_kernel(\n    key_ptr,\n    key_stride,\n    value_ptr,\n    value_stride,\n    k_cache_ptr,\n    v_cache_ptr,\n    slot_mapping_ptr,\n    D: tl.constexpr,\n):\n    idx = tl.program_id(0)\n    slot = tl.load(slot_mapping_ptr + idx)\n    if slot == -1: return\n    key_offsets = idx * key_stride + tl.arange(0, D)\n    value_offsets = idx * value_stride + tl.arange(0, D)\n    key = tl.load(key_ptr + key_offsets)\n    value = tl.load(value_ptr + value_offsets)\n    cache_offsets = slot * D + tl.arange(0, D)\n    tl.store(k_cache_ptr + cache_offsets, key)\n    tl.store(v_cache_ptr + cache_offsets, value)\n\n\ndef store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):\n    N, num_heads, head_dim = key.shape\n    D = num_heads * head_dim\n    assert key.stride(-1) == 1 and value.stride(-1) == 1\n    assert key.stride(1) == head_dim and value.stride(1) == head_dim\n    assert k_cache.stride(1) == D and v_cache.stride(1) == D\n    assert slot_mapping.numel() == N\n    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)\n\n\nclass Attention(nn.Module):\n\n    def __init__(\n        self,\n        num_heads,\n        head_dim,\n        scale,\n        num_kv_heads,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.scale = scale\n        self.num_kv_heads = num_kv_heads\n        self.k_cache = self.v_cache = torch.tensor([])\n\n    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):\n        context = get_context()\n        k_cache, v_cache = self.k_cache, self.v_cache\n        if k_cache.numel() and 
v_cache.numel():\n            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)\n        if context.is_prefill:\n            if context.block_tables is not None:    # prefix cache\n                k, v = k_cache, v_cache\n            o = flash_attn_varlen_func(q, k, v,\n                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,\n                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,\n                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)\n        else:    # decode\n            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,\n                                        cache_seqlens=context.context_lens, block_table=context.block_tables, \n                                        softmax_scale=self.scale, causal=True)\n        return o\n"
  },
  {
    "path": "nanovllm/layers/embed_head.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport torch.distributed as dist\n\nfrom nanovllm.utils.context import get_context\n\n\nclass VocabParallelEmbedding(nn.Module):\n\n    def __init__(\n        self,\n        num_embeddings: int,\n        embedding_dim: int,\n    ):\n        super().__init__()\n        self.tp_rank = dist.get_rank()\n        self.tp_size = dist.get_world_size()\n        assert num_embeddings % self.tp_size == 0\n        self.num_embeddings = num_embeddings\n        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size\n        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank\n        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition\n        self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))\n        self.weight.weight_loader = self.weight_loader\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):\n        param_data = param.data\n        shard_size = param_data.size(0)\n        start_idx = self.tp_rank * shard_size\n        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)\n        param_data.copy_(loaded_weight)\n\n    def forward(self, x: torch.Tensor):\n        if self.tp_size > 1:\n            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)\n            x = mask * (x - self.vocab_start_idx)\n        y = F.embedding(x, self.weight)\n        if self.tp_size > 1:\n            y = mask.unsqueeze(1) * y\n            dist.all_reduce(y)\n        return y\n\n\nclass ParallelLMHead(VocabParallelEmbedding):\n\n    def __init__(\n        self,\n        num_embeddings: int,\n        embedding_dim: int,\n        bias: bool = False,\n    ):\n        assert not bias\n        super().__init__(num_embeddings, embedding_dim)\n\n    def forward(self, x: torch.Tensor):\n        context = get_context()\n        if context.is_prefill:\n            
last_indices = context.cu_seqlens_q[1:] - 1\n            x = x[last_indices].contiguous()\n        logits = F.linear(x, self.weight)\n        if self.tp_size > 1:\n            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None\n            dist.gather(logits, all_logits, 0)\n            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None\n        return logits\n"
  },
  {
    "path": "nanovllm/layers/layernorm.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass RMSNorm(nn.Module):\n\n    def __init__(\n        self,\n        hidden_size: int,\n        eps: float = 1e-6,\n    ) -> None:\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n\n    @torch.compile\n    def rms_forward(\n        self,\n        x: torch.Tensor,\n    ) -> torch.Tensor:\n        orig_dtype = x.dtype\n        x = x.float()\n        var = x.pow(2).mean(dim=-1, keepdim=True)\n        x.mul_(torch.rsqrt(var + self.eps))\n        x = x.to(orig_dtype).mul_(self.weight)\n        return x\n\n    @torch.compile\n    def add_rms_forward(\n        self,\n        x: torch.Tensor,\n        residual: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        orig_dtype = x.dtype\n        x = x.float().add_(residual.float())\n        residual = x.to(orig_dtype)\n        var = x.pow(2).mean(dim=-1, keepdim=True)\n        x.mul_(torch.rsqrt(var + self.eps))\n        x = x.to(orig_dtype).mul_(self.weight)\n        return x, residual\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        residual: torch.Tensor | None = None,\n    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n        if residual is None:\n            return self.rms_forward(x)\n        else:\n            return self.add_rms_forward(x, residual)\n"
  },
  {
    "path": "nanovllm/layers/linear.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport torch.distributed as dist\n\n\ndef divide(numerator, denominator):\n    assert numerator % denominator == 0\n    return numerator // denominator\n\n\nclass LinearBase(nn.Module):\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = False,\n        tp_dim: int | None = None,\n    ):\n        super().__init__()\n        self.tp_dim = tp_dim\n        self.tp_rank = dist.get_rank()\n        self.tp_size = dist.get_world_size()\n        self.weight = nn.Parameter(torch.empty(output_size, input_size))\n        self.weight.weight_loader = self.weight_loader\n        if bias:\n            self.bias = nn.Parameter(torch.empty(output_size))\n            self.bias.weight_loader = self.weight_loader\n        else:\n            self.register_parameter(\"bias\", None)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError\n\n\nclass ReplicatedLinear(LinearBase):\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = False,\n    ):\n        super().__init__(input_size, output_size, bias)\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):\n        param.data.copy_(loaded_weight)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return F.linear(x, self.weight, self.bias)\n\n\nclass ColumnParallelLinear(LinearBase):\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = False,\n    ):\n        tp_size = dist.get_world_size()\n        super().__init__(input_size, divide(output_size, tp_size), bias, 0)\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):\n        param_data = param.data\n        shard_size = param_data.size(self.tp_dim)\n        start_idx = self.tp_rank * shard_size\n        loaded_weight = 
loaded_weight.narrow(self.tp_dim, start_idx, shard_size)\n        param_data.copy_(loaded_weight)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return F.linear(x, self.weight, self.bias)\n\n\nclass MergedColumnParallelLinear(ColumnParallelLinear):\n\n    def __init__(\n        self,\n        input_size: int,\n        output_sizes: list[int],\n        bias: bool = False,\n    ):\n        self.output_sizes = output_sizes\n        super().__init__(input_size, sum(output_sizes), bias)\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):\n        param_data = param.data\n        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size\n        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size\n        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)\n        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]\n        param_data.copy_(loaded_weight)\n\n\nclass QKVParallelLinear(ColumnParallelLinear):\n\n    def __init__(\n        self,\n        hidden_size: int,\n        head_size: int,\n        total_num_heads: int,\n        total_num_kv_heads: int | None = None,\n        bias: bool = False,\n    ):\n        tp_size = dist.get_world_size()\n        total_num_kv_heads = total_num_kv_heads or total_num_heads\n        self.head_size = head_size\n        self.num_heads = divide(total_num_heads, tp_size)\n        self.num_kv_heads = divide(total_num_kv_heads, tp_size)\n        output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size\n        super().__init__(hidden_size, output_size, bias)\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):\n        param_data = param.data\n        assert loaded_shard_id in [\"q\", \"k\", \"v\"]\n        if loaded_shard_id == \"q\":\n            shard_size = self.num_heads * self.head_size\n            shard_offset = 0\n        
elif loaded_shard_id == \"k\":\n            shard_size = self.num_kv_heads * self.head_size\n            shard_offset = self.num_heads * self.head_size\n        else:\n            shard_size = self.num_kv_heads * self.head_size\n            shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size\n        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)\n        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]\n        param_data.copy_(loaded_weight)\n\n\nclass RowParallelLinear(LinearBase):\n\n    def __init__(\n        self,\n        input_size: int,\n        output_size: int,\n        bias: bool = False,\n    ):\n        tp_size = dist.get_world_size()\n        super().__init__(divide(input_size, tp_size), output_size, bias, 1)\n\n    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):\n        param_data = param.data\n        shard_size = param_data.size(self.tp_dim)\n        start_idx = self.tp_rank * shard_size\n        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)\n        param_data.copy_(loaded_weight)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)\n        if self.tp_size > 1:\n            dist.all_reduce(y)\n        return y\n"
  },
  {
    "path": "nanovllm/layers/rotary_embedding.py",
    "content": "from functools import lru_cache\nimport torch\nfrom torch import nn\n\n\ndef apply_rotary_emb(\n    x: torch.Tensor,\n    cos: torch.Tensor,\n    sin: torch.Tensor,\n) -> torch.Tensor:\n    x1, x2 = torch.chunk(x.float(), 2, dim=-1)\n    y1 = x1 * cos - x2 * sin\n    y2 = x2 * cos + x1 * sin\n    return torch.cat((y1, y2), dim=-1).to(x.dtype)\n\n\nclass RotaryEmbedding(nn.Module):\n\n    def __init__(\n        self,\n        head_size: int,\n        rotary_dim: int,\n        max_position_embeddings: int,\n        base: float,\n    ) -> None:\n        super().__init__()\n        self.head_size = head_size\n        assert rotary_dim == head_size\n        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))\n        t = torch.arange(max_position_embeddings, dtype=torch.float)\n        freqs = torch.einsum(\"i,j -> ij\", t, inv_freq)\n        cos = freqs.cos()\n        sin = freqs.sin()\n        cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)\n        self.register_buffer(\"cos_sin_cache\", cache, persistent=False)\n\n    @torch.compile\n    def forward(\n        self,\n        positions: torch.Tensor,\n        query: torch.Tensor,\n        key: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        cos_sin = self.cos_sin_cache[positions]\n        cos, sin = cos_sin.chunk(2, dim=-1)\n        query = apply_rotary_emb(query, cos, sin)\n        key = apply_rotary_emb(key, cos, sin)\n        return query, key\n\n\n@lru_cache(1)\ndef get_rope(\n    head_size: int,\n    rotary_dim: int,\n    max_position: int,\n    base: float,\n    rope_scaling: dict | None = None,\n):\n    assert rope_scaling is None\n    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)\n    return rotary_emb\n"
  },
  {
    "path": "nanovllm/layers/sampler.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass Sampler(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n\n    @torch.compile\n    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):\n        logits = logits.float().div_(temperatures.unsqueeze(dim=1))\n        probs = torch.softmax(logits, dim=-1)\n        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)\n        return sample_tokens\n"
  },
  {
    "path": "nanovllm/llm.py",
    "content": "from nanovllm.engine.llm_engine import LLMEngine\n\n\nclass LLM(LLMEngine):\n    pass\n"
  },
  {
    "path": "nanovllm/models/qwen3.py",
    "content": "import torch\nfrom torch import nn\nimport torch.distributed as dist\nfrom transformers import Qwen3Config\n\nfrom nanovllm.layers.activation import SiluAndMul\nfrom nanovllm.layers.attention import Attention\nfrom nanovllm.layers.layernorm import RMSNorm\nfrom nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear\nfrom nanovllm.layers.rotary_embedding import get_rope\nfrom nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead\n\n\nclass Qwen3Attention(nn.Module):\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        num_kv_heads: int,\n        max_position: int = 4096 * 32,\n        head_dim: int | None = None,\n        rms_norm_eps: float = 1e-06,\n        qkv_bias: bool = False,\n        rope_theta: float = 10000,\n        rope_scaling: tuple | None = None,\n    ) -> None:\n        super().__init__()\n        tp_size = dist.get_world_size()\n        self.total_num_heads = num_heads\n        assert self.total_num_heads % tp_size == 0\n        self.num_heads = self.total_num_heads // tp_size\n        self.total_num_kv_heads = num_kv_heads\n        assert self.total_num_kv_heads % tp_size == 0\n        self.num_kv_heads = self.total_num_kv_heads // tp_size\n        self.head_dim = head_dim or hidden_size // self.total_num_heads\n        self.q_size = self.num_heads * self.head_dim\n        self.kv_size = self.num_kv_heads * self.head_dim\n        self.scaling = self.head_dim ** -0.5\n        self.qkv_bias = qkv_bias\n\n        self.qkv_proj = QKVParallelLinear(\n            hidden_size,\n            self.head_dim,\n            self.total_num_heads,\n            self.total_num_kv_heads,\n            bias=qkv_bias,\n        )\n        self.o_proj = RowParallelLinear(\n            self.total_num_heads * self.head_dim,\n            hidden_size,\n            bias=False,\n        )\n        self.rotary_emb = get_rope(\n            self.head_dim,\n        
    rotary_dim=self.head_dim,\n            max_position=max_position,\n            base=rope_theta,\n            rope_scaling=rope_scaling,\n        )\n        self.attn = Attention(\n            self.num_heads,\n            self.head_dim,\n            self.scaling,\n            self.num_kv_heads,\n        )\n        if not self.qkv_bias:\n            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)\n            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n    ) -> torch.Tensor:\n        qkv = self.qkv_proj(hidden_states)\n        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)\n        q = q.view(-1, self.num_heads, self.head_dim)\n        k = k.view(-1, self.num_kv_heads, self.head_dim)\n        v = v.view(-1, self.num_kv_heads, self.head_dim)\n        if not self.qkv_bias:\n            q = self.q_norm(q)\n            k = self.k_norm(k)\n        q, k = self.rotary_emb(positions, q, k)\n        o = self.attn(q, k, v)\n        output = self.o_proj(o.flatten(1, -1))\n        return output\n\n\nclass Qwen3MLP(nn.Module):\n\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: int,\n        hidden_act: str,\n    ) -> None:\n        super().__init__()\n        self.gate_up_proj = MergedColumnParallelLinear(\n            hidden_size,\n            [intermediate_size] * 2,\n            bias=False,\n        )\n        self.down_proj = RowParallelLinear(\n            intermediate_size,\n            hidden_size,\n            bias=False,\n        )\n        assert hidden_act == \"silu\"\n        self.act_fn = SiluAndMul()\n\n    def forward(self, x):\n        gate_up = self.gate_up_proj(x)\n        x = self.act_fn(gate_up)\n        x = self.down_proj(x)\n        return x\n\n\nclass Qwen3DecoderLayer(nn.Module):\n\n    def __init__(\n        self,\n        config: Qwen3Config,\n    ) -> None:\n     
   super().__init__()\n        self.self_attn = Qwen3Attention(\n            hidden_size=config.hidden_size,\n            num_heads=config.num_attention_heads,\n            num_kv_heads=config.num_key_value_heads,\n            max_position=config.max_position_embeddings,\n            rms_norm_eps=config.rms_norm_eps,\n            qkv_bias=getattr(config, 'attention_bias', True),\n            head_dim=getattr(config, 'head_dim', None),\n            rope_theta=getattr(config, \"rope_theta\", 1000000),\n            rope_scaling=getattr(config, \"rope_scaling\", None),\n        )\n        self.mlp = Qwen3MLP(\n            hidden_size=config.hidden_size,\n            intermediate_size=config.intermediate_size,\n            hidden_act=config.hidden_act,\n        )\n        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        positions: torch.Tensor,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor | None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        if residual is None:\n            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n        hidden_states = self.self_attn(positions, hidden_states)\n        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)\n        hidden_states = self.mlp(hidden_states)\n        return hidden_states, residual\n\n\nclass Qwen3Model(nn.Module):\n\n    def __init__(\n        self,\n        config: Qwen3Config,\n    ) -> None:\n        super().__init__()\n        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])\n        self.norm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        positions: torch.Tensor,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n        residual = None\n        for layer in self.layers:\n            hidden_states, residual = layer(positions, hidden_states, residual)\n        hidden_states, _ = self.norm(hidden_states, residual)\n        return hidden_states\n\n\nclass Qwen3ForCausalLM(nn.Module):\n    packed_modules_mapping = {\n        \"q_proj\": (\"qkv_proj\", \"q\"),\n        \"k_proj\": (\"qkv_proj\", \"k\"),\n        \"v_proj\": (\"qkv_proj\", \"v\"),\n        \"gate_proj\": (\"gate_up_proj\", 0),\n        \"up_proj\": (\"gate_up_proj\", 1),\n    }\n\n    def __init__(\n        self,\n        config: Qwen3Config\n    ) -> None:\n        super().__init__()\n        self.model = Qwen3Model(config)\n        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)\n        if config.tie_word_embeddings:\n            self.lm_head.weight.data = self.model.embed_tokens.weight.data\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        positions: torch.Tensor,\n    ) -> torch.Tensor:\n        return self.model(input_ids, positions)\n\n    def compute_logits(\n        self,\n        hidden_states: torch.Tensor,\n    ) -> torch.Tensor:\n        return self.lm_head(hidden_states)\n"
  },
  {
    "path": "nanovllm/sampling_params.py",
    "content": "from dataclasses import dataclass\n\n\n@dataclass\nclass SamplingParams:\n    temperature: float = 1.0\n    max_tokens: int = 64\n    ignore_eos: bool = False\n\n    def __post_init__(self):\n        assert self.temperature > 1e-10, \"greedy sampling is not permitted\"\n"
  },
  {
    "path": "nanovllm/utils/context.py",
    "content": "from dataclasses import dataclass\nimport torch\n\n\n@dataclass\nclass Context:\n    is_prefill: bool = False\n    cu_seqlens_q: torch.Tensor | None = None\n    cu_seqlens_k: torch.Tensor | None = None\n    max_seqlen_q: int = 0\n    max_seqlen_k: int = 0\n    slot_mapping: torch.Tensor | None = None\n    context_lens: torch.Tensor | None = None\n    block_tables: torch.Tensor | None = None\n\n_CONTEXT = Context()\n\ndef get_context():\n    return _CONTEXT\n\ndef set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):\n    global _CONTEXT\n    _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)\n\ndef reset_context():\n    global _CONTEXT\n    _CONTEXT = Context()\n"
  },
  {
    "path": "nanovllm/utils/loader.py",
    "content": "import os\nfrom glob import glob\nimport torch\nfrom torch import nn\nfrom safetensors import safe_open\n\n\ndef default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):\n    param.data.copy_(loaded_weight)\n\n\ndef load_model(model: nn.Module, path: str):\n    packed_modules_mapping = getattr(model, \"packed_modules_mapping\", {})\n    for file in glob(os.path.join(path, \"*.safetensors\")):\n        with safe_open(file, \"pt\", \"cpu\") as f:\n            for weight_name in f.keys():\n                for k in packed_modules_mapping:\n                    if k in weight_name:\n                        v, shard_id = packed_modules_mapping[k]\n                        param_name = weight_name.replace(k, v)\n                        param = model.get_parameter(param_name)\n                        weight_loader = getattr(param, \"weight_loader\")\n                        weight_loader(param, f.get_tensor(weight_name), shard_id)\n                        break\n                else:\n                    param = model.get_parameter(weight_name)\n                    weight_loader = getattr(param, \"weight_loader\", default_weight_loader)\n                    weight_loader(param, f.get_tensor(weight_name))\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"nano-vllm\"\nversion = \"0.2.0\"\nauthors = [{ name = \"Xingkai Yu\" }]\nlicense = \"MIT\"\nlicense-files = [\"LICENSE\"]\nreadme = \"README.md\"\ndescription = \"a lightweight vLLM implementation built from scratch\"\nrequires-python = \">=3.10,<3.13\"\ndependencies = [\n    \"torch>=2.4.0\",\n    \"triton>=3.0.0\",\n    \"transformers>=4.51.0\",\n    \"flash-attn\",\n    \"xxhash\",\n]\n\n[project.urls]\nHomepage=\"https://github.com/GeeeekExplorer/nano-vllm\"\n\n[tool.setuptools.packages.find]\nwhere = [\".\"]\ninclude = [\"nanovllm*\"]\n"
  }
]